Skip to content

Commit 8424b81

Browse files
committed
solved the DDP problem
1 parent f986181 commit 8424b81

File tree

11 files changed

+20051
-34169
lines changed

11 files changed

+20051
-34169
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,5 @@ FLAME/
2323
# ignore the data
2424
data/HDTF_TFHP
2525
data/MNIST
26+
data/VOCASET
2627
data/data_pipline/audio_visual_dataset/

base/base_config.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,19 @@ def __init__(self):
102102
###########################
103103
cfg.MODEL = CN()
104104
cfg.MODEL.NAME = ""
105+
cfg.MODEL.INIT_WEIGHTS = "" # Path to model weights (for initialization)
106+
cfg.MODEL.AUDIO_MODEL = 'wav2vec2'
107+
cfg.MODEL.AUDIO_DIM = 128
108+
105109
cfg.MODEL.MLP = CN()
106110
cfg.MODEL.MLP.INPUT_DIM = 784
107111
cfg.MODEL.MLP.HIDDEN_DIM = [128, 64]
108112
cfg.MODEL.MLP.OUTPUT_DIM = 10
109-
# Path to model weights (for initialization)
110-
cfg.MODEL.INIT_WEIGHTS = ""
113+
114+
115+
116+
117+
111118
# Definition of embedding layers
112119
cfg.MODEL.HEAD = CN()
113120
# If none, do not construct embedding layers, the
@@ -119,16 +126,10 @@ def __init__(self):
119126
cfg.MODEL.HEAD.ACTIVATION = "relu"
120127
cfg.MODEL.HEAD.BN = True
121128
cfg.MODEL.HEAD.DROPOUT = 0.0
122-
# VQ-VAE config
123-
cfg.MODEL.HEAD.N_EMBED = 256
124-
cfg.MODEL.HEAD.ZQUANT_DIM = 64
125-
# Audio model
126-
cfg.MODEL.HEAD.AUDIO_MODEL = 'wav2vec2'
127-
cfg.MODEL.HEAD.AUDIO_DIM = 128
128-
# Style ref
129-
cfg.MODEL.HEAD.STYLE_DIM = 128
130-
# Use indicator for padding frames
131-
cfg.MODEL.HEAD.USE_INDICATOR = False
129+
130+
131+
132+
cfg.MODEL.HEAD.USE_INDICATOR = False # Use indicator for padding frames
132133

133134
# optional head type according to different input
134135
cfg.MODEL.HEAD.ROT_REPR = 'aa'

base/base_trainer.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@
3232
def build_trainer(cfg):
3333
avai_trainers = TRAINER_REGISTRY.registered_names()
3434
check_availability(cfg.TRAINER.NAME, avai_trainers)
35-
if cfg.ENV.VERBOSE:
36-
print("Loading trainer: {}".format(cfg.TRAINER.NAME))
3735
return TRAINER_REGISTRY.get(cfg.TRAINER.NAME)(cfg)
3836

3937
class TrainerBase:
@@ -98,21 +96,17 @@ def system_init(self):
9896
## cuda setting
9997
if torch.cuda.is_available() and self.cfg.ENV.USE_CUDA:
10098
torch.backends.cudnn.benchmark = True
101-
gpu_ids = self.cfg.ENV.GPU
102-
if not gpu_ids:
103-
raise ValueError("ENV.GPU must contain at least one gpu id when USE_CUDA=True")
104-
10599
if self.is_distributed:
106100
# In distributed mode, use local_rank to determine GPU
107-
target_gpu = gpu_ids[self.local_rank % len(gpu_ids)]
101+
target_gpu = self.cfg.ENV.GPU[self.local_rank % len(self.cfg.ENV.GPU)]
108102
else:
109-
target_gpu = gpu_ids[0]
110-
if len(gpu_ids) > 1 and torch.distributed.is_available():
103+
target_gpu = self.cfg.ENV.GPU[0]
104+
if len(self.cfg.ENV.GPU) > 1 and torch.distributed.is_available():
111105
# assume torchrun/launch supplies LOCAL_RANK; fallback to rank % len(gpu_ids)
112106
local_rank = int(os.environ.get("LOCAL_RANK", 0))
113107
if torch.distributed.is_initialized():
114-
local_rank = torch.distributed.get_rank() % len(gpu_ids)
115-
target_gpu = gpu_ids[local_rank % len(gpu_ids)]
108+
local_rank = torch.distributed.get_rank() % len(self.cfg.ENV.GPU)
109+
target_gpu = self.cfg.ENV.GPU[local_rank % len(self.cfg.ENV.GPU)]
116110

117111
self.device = torch.device(f"cuda:{target_gpu}")
118112
torch.cuda.set_device(self.device)
@@ -126,17 +120,11 @@ def _init_distributed(self):
126120
# Get local rank from environment variable (set by torchrun)
127121
self.local_rank = int(os.environ.get('LOCAL_RANK', -1))
128122

129-
if self.local_rank == -1:
130-
print("LOCAL_RANK not found in environment. Falling back to non-distributed mode.")
131-
self.cfg.ENV.DISTRIBUTED = False
132-
return
133-
134123
# Initialize process group
135124
dist.init_process_group(
136125
backend=self.cfg.ENV.DIST_BACKEND,
137126
init_method=self.cfg.ENV.DIST_URL
138127
)
139-
140128
self.rank = dist.get_rank()
141129
self.world_size = dist.get_world_size()
142130
self.is_distributed = True

config/difftalk_trainer_config.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ ENV:
1414
TAGS: [Baseline]
1515
MODE: online
1616
EXTRA:
17-
STYLE_ENC_CKPT:
17+
STYLE_DIM: 128
18+
STYLE_ENC_CKPT:
1819

1920
DATASET:
2021
NAME: HDTF_TFHP

config/toy_trainer_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
ENV:
44
SEED: 42
55
OUTPUT_DIR: ./output
6-
GPU: [0] # Multi-GPU training
6+
GPU: [0, 1] # List format - will be parsed as a list
77
USE_CUDA: True
88
VERBOSE: True
99

dataset/MINIST.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""
22
MNIST Dataset for Handwritten Digit Recognition
33
"""
4-
import os
54
from torchvision import datasets, transforms
65

76
from base.base_dataset import Datum, DatasetBase, DATASET_REGISTRY

0 commit comments

Comments (0)