(pytorch) torch.nn.SyncBatchNorm is not supported and its existence in a model will cause an exception during use of smdistributed #2571

Description

@dwhite54

This could easily be considered a bug report, but it's hard to call this your bug.

What did you find confusing? Please describe.
There is no mention of torch.nn.SyncBatchNorm, or of its potential to (confusingly) break a SageMaker training job.

Describe how documentation can be improved
Mention that initializing smdistributed with init_process_group will not allow use of certain other torch features, such as (and likely not limited to) torch.nn.SyncBatchNorm.

Additional context
torch.nn.SyncBatchNorm calls torch.distributed.get_world_size (not smdistributed...get_world_size), which causes an exception since torch.distributed hasn't been initialized.
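For reference, a minimal sketch of the failure path. This just exercises the call SyncBatchNorm makes internally during forward; the exact exception type and message may vary by PyTorch version:

```python
import torch.distributed as dist

# SyncBatchNorm.forward() looks up the world size of torch.distributed's
# default process group. Under smdistributed, only the smdistributed process
# group has been initialized, so this call is what raises mid-training.
try:
    dist.get_world_size()
except Exception as e:
    # Typically: "Default process group has not been initialized, please make
    # sure to call init_process_group." (wording depends on the torch version)
    print(type(e).__name__, e)
```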

An alternative to a documentation fix would be to scan the estimator source for mentions of SyncBatchNorm. Another alternative would be to somehow get torch.nn to import smdistributed.dataparallel.torch.distributed. Even better would be a new SM API replacement for torch.nn, but that seems excessive!
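In the meantime, a user-side workaround is to strip SyncBatchNorm out of the model before training with smdistributed. A minimal sketch, assuming a conv net that uses 2d batch norm; `revert_sync_batchnorm` is a hypothetical helper, not an official torch or SageMaker API (it roughly inverts nn.SyncBatchNorm.convert_sync_batchnorm):

```python
import torch
import torch.nn as nn

def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
    """Recursively replace SyncBatchNorm layers with plain BatchNorm2d so the
    model never touches torch.distributed. Adjust the target class if the
    model uses 1d or 3d inputs."""
    converted = module
    if isinstance(module, nn.SyncBatchNorm):
        converted = nn.BatchNorm2d(
            module.num_features,
            eps=module.eps,
            momentum=module.momentum,
            affine=module.affine,
            track_running_stats=module.track_running_stats,
        )
        if module.affine:
            with torch.no_grad():
                converted.weight.copy_(module.weight)
                converted.bias.copy_(module.bias)
        # Carry over the running statistics (buffers), if tracked.
        converted.running_mean = module.running_mean
        converted.running_var = module.running_var
        converted.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        converted.add_module(name, revert_sync_batchnorm(child))
    return converted
```

The training script would call `model = revert_sync_batchnorm(model)` before wrapping the model with smdistributed's DistributedDataParallel wrapper, at the cost of losing cross-device batch-norm synchronization.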
