Dit unit tests #68
Changes from 10 commits
f976ccf
8f324b5
d271bd6
2f46b43
728250e
fa1b884
dee1153
644970d
a87268b
3e93c2a
4723028
79f98d1
f859fbb
99ada83
f007a7f
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | | `@@ -55,7 +55,11 @@ def __post_init__(self):` |
| | | `         self.sequence_length = self.dataset.seq_length` |
| | | |
| | | `     def build_datasets(self, context: DatasetBuildContext):` |
| | | `-        return self.dataset.train_dataloader(), self.dataset.val_dataloader(), self.dataset.test_dataloader()` |
| | | `+        return (` |
| | | `+            iter(self.dataset.train_dataloader()),` |
| | | Contributor comment: Nice improvement! Wrapping dataloaders with `iter()` makes the interface more explicit and reduces potential confusion in downstream usage. |
| | | `+            iter(self.dataset.val_dataloader()),` |
| | | `+            iter(self.dataset.val_dataloader()),` |
| | | `+        )` |
| | | |
| | | |
| | | ` class DiffusionDataModule(EnergonMultiModalDataModule):` |
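
The review comment about `iter()` above turns on the difference between an iterable and an iterator. The following is a minimal sketch of that difference, not the project's code: `TinyLoader` is a hypothetical stand-in for a dataloader, assumed only to implement `__iter__`, so the example runs without any training framework installed.

```python
class TinyLoader:
    """Hypothetical stand-in for a dataloader: an iterable, not an iterator."""

    def __init__(self, batches):
        self.batches = list(batches)

    def __iter__(self):
        # Each call hands out a fresh iterator over the same batches.
        return iter(self.batches)


loader = TinyLoader([1, 2, 3])

# An iterable that lacks __next__ cannot be passed to next() directly.
try:
    next(loader)
    loader_supports_next = True
except TypeError:
    loader_supports_next = False

# Wrapping it with iter() yields an object that next() accepts.
batch_iter = iter(loader)
first = next(batch_iter)
rest = list(batch_iter)

# The iterator is single-pass: once drained, it stays empty.
second_pass = list(batch_iter)

# The underlying iterable can still be traversed again from the start.
fresh_pass = list(loader)
```

One consequence worth keeping in mind: a caller that receives the iterator rather than the loader can call `next()` on it immediately, but cannot restart it after exhaustion and would need a new `iter(...)` call on the underlying loader for another pass.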
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | | `@@ -32,7 +32,7 @@` |
| | | ` from dfm.src.megatron.data.common.diffusion_energon_datamodule import DiffusionDataModuleConfig` |
| | | ` from dfm.src.megatron.data.dit.dit_mock_datamodule import DiTMockDataModuleConfig` |
| | | `-from dfm.src.megatron.model.dit.dit_model_provider import DiTModelProvider` |
| | | `+from dfm.src.megatron.model.dit.dit_model_provider import DiTModelProvider, DiTXLModelProvider` |
| | | |
| | | |
| | | ` def model_config(` |
| | | `@@ -57,7 +57,7 @@ def model_config(` |
| | | `     Returns:` |
| | | `         DiTModelProvider: Configuration for the DiT-S model.` |
| | | `     """` |
| | | `-    return DiTModelProvider(` |
| | | `+    return DiTXLModelProvider(` |
| | | Contributor comment: Nit: the docstring says "DiT-S", but the function now returns `DiTXLModelProvider`. |
| | | `         tensor_model_parallel_size=tensor_parallelism,` |
| | | `         pipeline_model_parallel_size=pipeline_parallelism,` |
| | | `         pipeline_dtype=pipeline_parallelism_dtype,` |
Good cleanup! Removing debug print statements keeps the output clean in production.