Skip to content

Commit 11515d0

Browse files
committed
Use bf16 as the autocast dtype for AMP (automatic mixed precision) when training with the torch_ddp engine
1 parent b048a2d commit 11515d0

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

cosyvoice/utils/executor.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -166,7 +166,7 @@ def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True):
166166
for k, v in info_dict['loss_dict'].items():
167167
if k not in total_loss_dict:
168168
total_loss_dict[k] = []
169-
total_loss_dict[k].append(v.item() * num_utts)
169+
total_loss_dict[k].append(v.mean().item() * num_utts)
170170
log_per_step(None, info_dict)
171171
for k, v in total_loss_dict.items():
172172
total_loss_dict[k] = sum(v) / total_num_utts

cosyvoice/utils/train_utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,7 @@ def init_dataset_and_dataloader(args, configs, gan, dpo):
7171

7272
def check_modify_and_save_config(args, configs):
7373
if args.train_engine == "torch_ddp":
74-
configs['train_conf']["dtype"] = 'fp32'
74+
configs['train_conf']["dtype"] = 'bf16' if args.use_amp is True else 'fp32'
7575
else:
7676
with open(args.deepspeed_config, 'r') as fin:
7777
ds_configs = json.load(fin)
@@ -247,7 +247,7 @@ def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None
247247
dtype = torch.float32
248248

249249
if info_dict['train_engine'] == 'torch_ddp':
250-
autocast = torch.cuda.amp.autocast(enabled=scaler is not None)
250+
autocast = torch.cuda.amp.autocast(enabled=scaler is not None, dtype=dtype)
251251
else:
252252
autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
253253

0 commit comments

Comments
 (0)