Effective Cloud Removal for Remote Sensing Images by an Improved Mean-Reverting Denoising Model with Elucidated Design Space
这是EMRDM的官方实现仓库。EMRDM是一种新型的去云方法,其基于改进的均值回归的扩散模型。相比于以往的均值回归扩散模型,EMRDM拥有模块化的架构,该架构提供了一个灵活而清晰的设计空间,该架构的模块也经过重新优化设计。经过这些改进,EMRDM在单时态和多时态去云任务的公开数据集上都取得了SOTA的效果。
- 🎯我们发布了EMRDM的官方实现;
- 🎯我们发布了在CUHK-CR1,CUHK-CR2,SEN12MS-CR和Sen2_MTC_New四个数据集上训练的EMRDM模型的权重;
- 🎯我们提供了一份中文的readme文件;
- 🎯我们在本仓库提供了引用我们的文章的bibtex范例。
我们在本项目的根目录下面提供了requirements.txt,我们在虚拟环境里面使用的所有包的版本都写在了这个文件里面,但是不推荐直接用pip install -r requirements.txt从这个文件里面安装所有的包,因为这些包的依赖关系非常复杂,必须按照一定的顺序安装,且这个文件里面写的一些包是冗余的。推荐按照我们下面提供的详细步骤来安装包。
首先我们推荐安装一些比较大的且比较关键的包,例如torch,flash_attn,natten和pytorch-lightning。在安装这些包之前,可以使用conda来管理虚拟环境,conda创建虚拟环境的命令是:
conda create --name emrdm python=3.10
conda activate emrdm然后,使用下面的命令来安装torch,顺便指定一下numpy的版本,把numpy也下载下来:
pip install torch==2.2.1 torchaudio==2.2.1 torchvision==0.17.1 numpy==1.26.4需要注意的是,我们使用的CUDA版本是CUDA 12.1,这个版本信息可能在安装torch的时候有用。
接下来可以用下面的命令安装flash_attn:
MAX_JOBS=4 pip install flash_attn==2.5.9.post1 --no-build-isolation但是其实这条命令会从头编译flash_attn,非常消耗时间。为了加速安装,可以从flash_attn的官方仓库下载一下wheel文件(注意是2.5.9.post1版本的wheel文件)然后使用下面的命令从wheel文件安装包:
pip install [name_of_flash_attn_wheel].whl之后,要下载natten包,可以用如下命令:
pip install natten==0.17.1同样这条命令会从头编译natten,很耗时间。你也可以到natten的官方仓库下载wheel文件来加速。但是natten官方将他们的wheel文件放到了一个网站上,所以你可以通过pip -f指定wheel所在的网站来进行natten的下载,如下:
pip install natten==0.17.1+torch220cu121 -f https://shi-labs.com/natten/wheels下载pytorch-lightning,可以使用如下这条命令:
pip install pytorch-lightning==2.3.0在上面这些关键的包下载完之后,你可以用下面这些命令下载其他的辅助的包:
pip install wandb==0.17.8 matplotlib==3.9.2 natsort==8.4.0 \
omegaconf==2.3.0 scipy==1.14.0 dctorch==0.1.2 rasterio==1.3.11
pip install pandas==2.2.3 opencv-python==4.10.0.84 lpips==0.1.4
pip install tifffile==2024.7.24 s2cloudless==1.7.2 \
albumentations==1.4.10 albucore==0.0.12(可选)如果你还是碰到了包丢失的问题,你可以参考requirements.txt文件来下载你缺失的包,这个文件里面提供了我们使用的虚拟环境里面所有包的版本信息。
如果你还遇到了其他的上面没提到的环境配置问题,请和我们联系并将你遇到的问题报告给我们,或者在github上发一个issue。
我们使用了四个数据集:CUHK-CR1,CUHK-CR2,SEN12MS-CR和Sen2_MTC_New。你需要先把这些数据集下载好。
下面我们提供了下载这些数据集的网址:
| 数据集 | 类型 | 网址 |
|---|---|---|
| CUHK-CR1 | 单时态 | https://github.com/littlebeen/DDPM-Enhancement-for-Cloud-Removal |
| CUHK-CR2 | 单时态 | https://github.com/littlebeen/DDPM-Enhancement-for-Cloud-Removal |
| SEN12MS-CR | 单时态 | https://patricktum.github.io/cloud_removal/sen12mscr/ |
| Sen2_MTC_New | 多时态 | https://github.com/come880412/CTGAN |
如果你想快点把代码跑起来,你可以先下载一下测试数据集,然后跑下面提供的测试命令。
在./configs/example_training/路径下面,我们提供了配置文件,也就是*.yaml文件。代码会自动读取yaml文件并设置代码运行的配置信息。你可以在yaml文件中修改这些配置,例如数据路径,batch size,读取数据的worker数目这些数据集的配置可以在yaml文件里面data这部分修改。如果想知道更详细的配置方法,可以看一下./configs/example_training/路径下面提供的范例yaml文件
我们也在./configs/example_training/ablation/路径下提供了我们进行消融实验时使用的配置文件,如果你有兴趣可以看一看。
在本项目的根路径下,使用下面的命令可以进行模型的训练:
python main.py --base configs/example_training/[yaml_file_name].yaml --enable_tf32我们在./configs/example_training/提供了四个数据集训练的配置文件,cuhk.yaml是用来在CUHK-CR1上训练的配置文件,cuhkv2.yaml是用来在CUHK-CR2上训练的配置文件,sen2_mtc_new.yaml 是用来在Sen2_MTC_New数据集上训练的配置文件,sentinel.yaml是用来在SEN12MS-CR数据集上训练的配置文件。注意:你需要修改一下这些yaml文件里面data.params.train部分的内容来适应你自己的数据集的配置(例如修改数据所在路径)。
你也可以在命令行上用-l参数来改变训练过程中日志文件的保存位置,不修改的话./logs就是默认的保存位置:
python main.py --base configs/example_training/[yaml_file_name].yaml --enable_tf32 -l [path_to_your_logs]如果你想从一个之前训练好的检查点(checkpoint)继续训练,你可以使用命令行参数-r指定检查点的位置,如下:
python main.py --base configs/example_training/[yaml_file_name].yaml --enable_tf32 -r [path_to_your_ckpt]如果你想从一个已训练好的检查点初始化模型,但是重启训练过程(即不从上次训练的中止epoch继续训练),你可以将yaml文件中的model.ckpt_path的值改为你的检查点的路径。
测试使用下面的命令:
python main.py --base configs/example_training/[yaml_file_name].yaml --enable_tf32 -t false[yaml_file_name].yaml可以使用和训练过程中一样的yaml文件,但是需要做如下一些修改:
- 你需要设置
yaml里面的data.params.test部分,以适配你的测试数据集的信息,否则测试集的dataloader将不会被实现,测试过程也就不会进行。 - 你需要修改
yaml里面model.ckpt_path的值,将之改为你训练好的(或者是本项目提供的)检查点的路径。我们提供了一些已经训练好的EMRDM的检查点,可用于测试,请参考模型一节。
预测过程会输出所有预测的去云图像(确保你有足够的硬盘空间),这个过程目前只支持单GPU运行,你需要修改yaml文件中的lightning.trainer.devices部分,只能设置一个设备。然后通过下面的命令运行预测过程:
python main.py --base configs/example_training/[yaml_file_name].yaml --enable_tf32 -t false --no-test true --predict true[yaml_file_name].yaml和测试过程使用的是一样的。注意你还需要设置一下yaml文件里面的data.params.predict部分以及model.ckpt_path的值,设置方法和训练过程一样,不设置的话将无法得到正确结果。
我们公布了所有的训练结果,包括训练的配置文件,训练得到的模型权重,还有训练过程的日志文件。
我们也公布了所有的测试结果,包括测试的配置文件,测试过程的日志文件。
所有的文件都可以通过下表展示的网址来下载:
| 谷歌云盘 | 阿里云盘 | 百度云盘 |
|---|---|---|
| https://drive.google.com/drive/folders/1T3OwRNP5r5qVLQZujnl2WDBVXHC1Am65?usp=sharing | https://www.alipan.com/s/39BcJezgsBC | https://pan.baidu.com/s/1RqYgluNNcYKXOa33kQioMQ |
在这些共享文件里面,train文件夹下放的是训练结果,test文件夹下放的是测试结果。百度云盘的分享码是6161。
The code is based on the official implementations of the generative-models, k-diffusion, utae-paps and other repositories, as follows:
本仓库的代码主要基于generative-models,k-diffusion,utae-paps和其他一些以往的开源仓库的实现,我们将这些开源仓库和我们的项目的关系总结如下:
| 仓库 | 网址 | 和本仓库的关系 |
|---|---|---|
generative-models |
https://github.com/Stability-AI/generative-models | 我们的代码的核心架构是基于该仓库。 |
k-diffusion |
https://github.com/crowsonkb/k-diffusion | 我们的代码中的去噪神经网络的实现是基于该仓库。 |
utae-paps |
https://github.com/VSainteuf/utae-paps | 我们的时序融合注意力机制的实现是基于该仓库。 |
UnCRtainTS |
https://github.com/PatrickTUM/UnCRtainTS | 我们代码中的SEN12MS-CR数据集的dataloader的实现是基于该仓库。 |
SEN12MS-CR-TS |
https://github.com/PatrickTUM/SEN12MS-CR-TS | 我们代码中的SEN12MS-CR数据集的dataloader的实现是基于该仓库。 |
CTGAN |
https://github.com/come880412/CTGAN | 我们代码中的Sen2_MTC_New数据集的dataloader的实现是基于该仓库。 |
DDPM-Enhancement-for-Cloud-Removal |
https://github.com/littlebeen/DDPM-Enhancement-for-Cloud-Removal | 我们代码中的CUHK-CR1和CUHK-CR2数据集的dataloader的实现是基于该仓库。 |
我们向这些仓库的开发者致以诚挚的感谢,感谢他们对开源社区做出的巨大贡献,没有他们的项目就不会有本仓库。
如果你遇到了任何问题,任何时候都可以通过下面两个邮箱联系我:liuyi2052697@foxmail.com (推荐)和liuyi61@tongji.edu.cn。只要我有时间,我都会尽力回答你的问题。
@inproceedings{liu2025effective,
title={Effective cloud removal for remote sensing images by an improved mean-reverting denoising model with elucidated design space},
author={Liu, Yi and Li, Wengen and Guan, Jihong and Zhou, Shuigeng and Zhang, Yichao},
booktitle={Proceedings of the Computer Vision and Pattern Recognition Conference},
pages={17851--17861},
year={2025}
}