Skip to content

Commit e6375d0

Browse files
committed
upload training code
1 parent 64413c3 commit e6375d0

24 files changed

+3799
-395
lines changed

accelerate_configs/01234567.yaml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
# Accelerate launch configuration: single machine, eight GPUs (ids 0-7).
# The launch scripts in this commit select it via
# --config_file=accelerate_configs/$CUDA.yaml with CUDA=01234567.
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
# Eight GPU ids are listed; num_processes below matches that count.
gpu_ids: '0,1,2,3,4,5,6,7'
# Single-machine run: this is machine 0 of num_machines: 1.
machine_rank: 0
main_training_function: main
# NOTE(review): the launch scripts pass --mixed_precision="fp16" on the
# command line, which presumably takes precedence over this value - confirm.
mixed_precision: 'no'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
# TPU settings are left empty/disabled; use_cpu is false.
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

datasets/quick_validation/00.png

904 KB
Loading

datasets/quick_validation/01.jpg

80.9 KB
Loading

datasets/quick_validation/02.jpg

55.5 KB
Loading

datasets/quick_validation/03.jpg

285 KB
Loading

datasets/quick_validation/04.jpg

59 KB
Loading

train_lotus_d.py

Lines changed: 1150 additions & 0 deletions
Large diffs are not rendered by default.

train_lotus_g.py

Lines changed: 1201 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
#!/usr/bin/env bash
# Launches train_lotus_d.py for the "depth" task through `accelerate launch`.
#
# Required environment:
#   PATH_TO_HYPERSIM_DATA  forwarded as --train_data_dir_hypersim
#   PATH_TO_VKITTI_DATA    forwarded as --train_data_dir_vkitti
#
# The config, script, dataset and output paths below are relative, so run this
# from the directory that contains accelerate_configs/ and train_lotus_d.py.

# export PYTHONPATH="$(dirname "$(dirname "$0")"):$PYTHONPATH"

# Fail fast with a clear message instead of silently passing an empty
# --train_data_dir_* argument when a dataset path has not been provided.
: "${PATH_TO_HYPERSIM_DATA:?set PATH_TO_HYPERSIM_DATA before running this script}"
: "${PATH_TO_VKITTI_DATA:?set PATH_TO_VKITTI_DATA before running this script}"

export MODEL_NAME="stabilityai/stable-diffusion-2-base"

# training dataset
export TRAIN_DATA_DIR_HYPERSIM="$PATH_TO_HYPERSIM_DATA"
export TRAIN_DATA_DIR_VKITTI="$PATH_TO_VKITTI_DATA"
export RES_HYPERSIM=576
export RES_VKITTI=375
export P_HYPERSIM=0.9
export NORMTYPE="trunc_disparity"

# training configs
export BATCH_SIZE=16
# CUDA is a string of GPU ids: it names the accelerate config file and its
# character count (${#CUDA}) is used as the GPU count below, so this only
# works with single-digit ids.
export CUDA=01234567
export GAS=1
# Total batch size = per-device batch size * number of GPUs * GAS.
export TOTAL_BSZ=$(( BATCH_SIZE * ${#CUDA} * GAS ))

# model configs
export TIMESTEP=999
export TASK_NAME="depth"

# eval
export BASE_TEST_DATA_DIR="datasets/eval/"
export VALIDATION_IMAGES="datasets/quick_validation/"
export VAL_STEP=500

# output dir
export OUTPUT_DIR="output/train-lotus-d-${TASK_NAME}-bsz${TOTAL_BSZ}/"

# Every expansion is quoted so values containing spaces or glob characters
# reach the trainer as a single argument.
accelerate launch --config_file="accelerate_configs/${CUDA}.yaml" --mixed_precision="fp16" \
  --main_process_port="13324" \
  train_lotus_d.py \
  --pretrained_model_name_or_path="$MODEL_NAME" \
  --train_data_dir_hypersim="$TRAIN_DATA_DIR_HYPERSIM" \
  --resolution_hypersim="$RES_HYPERSIM" \
  --train_data_dir_vkitti="$TRAIN_DATA_DIR_VKITTI" \
  --resolution_vkitti="$RES_VKITTI" \
  --prob_hypersim="$P_HYPERSIM" \
  --mix_dataset \
  --random_flip \
  --norm_type="$NORMTYPE" \
  --dataloader_num_workers=0 \
  --train_batch_size="$BATCH_SIZE" \
  --gradient_accumulation_steps="$GAS" \
  --gradient_checkpointing \
  --max_grad_norm=1 \
  --seed=42 \
  --max_train_steps=20000 \
  --learning_rate=3e-05 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --task_name="$TASK_NAME" \
  --timestep="$TIMESTEP" \
  --validation_images="$VALIDATION_IMAGES" \
  --validation_steps="$VAL_STEP" \
  --checkpointing_steps="$VAL_STEP" \
  --base_test_data_dir="$BASE_TEST_DATA_DIR" \
  --output_dir="$OUTPUT_DIR" \
  --resume_from_checkpoint="latest"
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
#!/usr/bin/env bash
# Launches train_lotus_d.py for the "normal" task through `accelerate launch`.
#
# Required environment:
#   PATH_TO_HYPERSIM_DATA  forwarded as --train_data_dir_hypersim
#   PATH_TO_VKITTI_DATA    forwarded as --train_data_dir_vkitti
#
# The config, script, dataset and output paths below are relative, so run this
# from the directory that contains accelerate_configs/ and train_lotus_d.py.

# export PYTHONPATH="$(dirname "$(dirname "$0")"):$PYTHONPATH"

# Fail fast with a clear message instead of silently passing an empty
# --train_data_dir_* argument when a dataset path has not been provided.
: "${PATH_TO_HYPERSIM_DATA:?set PATH_TO_HYPERSIM_DATA before running this script}"
: "${PATH_TO_VKITTI_DATA:?set PATH_TO_VKITTI_DATA before running this script}"

export MODEL_NAME="stabilityai/stable-diffusion-2-base"

# training dataset
export TRAIN_DATA_DIR_HYPERSIM="$PATH_TO_HYPERSIM_DATA"
export TRAIN_DATA_DIR_VKITTI="$PATH_TO_VKITTI_DATA"
export RES_HYPERSIM=576
export RES_VKITTI=375
export P_HYPERSIM=0.9

# training configs
export BATCH_SIZE=16
# CUDA is a string of GPU ids: it names the accelerate config file and its
# character count (${#CUDA}) is used as the GPU count below, so this only
# works with single-digit ids.
export CUDA=01234567
export GAS=1
# Total batch size = per-device batch size * number of GPUs * GAS.
export TOTAL_BSZ=$(( BATCH_SIZE * ${#CUDA} * GAS ))

# model configs
export TIMESTEP=999
export TASK_NAME="normal"

# eval
export BASE_TEST_DATA_DIR="datasets/eval/"
export VALIDATION_IMAGES="datasets/quick_validation/"
export VAL_STEP=500

# output dir
export OUTPUT_DIR="output/train-lotus-d-${TASK_NAME}-bsz${TOTAL_BSZ}/"

# Every expansion is quoted so values containing spaces or glob characters
# reach the trainer as a single argument.
accelerate launch --config_file="accelerate_configs/${CUDA}.yaml" --mixed_precision="fp16" \
  --main_process_port="13324" \
  train_lotus_d.py \
  --pretrained_model_name_or_path="$MODEL_NAME" \
  --train_data_dir_hypersim="$TRAIN_DATA_DIR_HYPERSIM" \
  --resolution_hypersim="$RES_HYPERSIM" \
  --train_data_dir_vkitti="$TRAIN_DATA_DIR_VKITTI" \
  --resolution_vkitti="$RES_VKITTI" \
  --prob_hypersim="$P_HYPERSIM" \
  --mix_dataset \
  --random_flip \
  --align_cam_normal \
  --dataloader_num_workers=0 \
  --train_batch_size="$BATCH_SIZE" \
  --gradient_accumulation_steps="$GAS" \
  --gradient_checkpointing \
  --max_grad_norm=1 \
  --seed=42 \
  --max_train_steps=20000 \
  --learning_rate=3e-05 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --task_name="$TASK_NAME" \
  --timestep="$TIMESTEP" \
  --validation_images="$VALIDATION_IMAGES" \
  --validation_steps="$VAL_STEP" \
  --checkpointing_steps="$VAL_STEP" \
  --base_test_data_dir="$BASE_TEST_DATA_DIR" \
  --output_dir="$OUTPUT_DIR" \
  --resume_from_checkpoint="latest"

0 commit comments

Comments
 (0)