Skip to content

Commit 160d3ad

Browse files
Fix losses and aggregators when using CUDA graphs (#280)
* Pass step as scalar tensor for CUDA graph compatibility Signed-off-by: Jason Ye <jasonyecanada@gmail.com> * Update turbulent channel example aggregator for CUDA graphs Signed-off-by: Jason Ye <jasonyecanada@gmail.com> * Update aggregator tests Signed-off-by: Jason Ye <jasonyecanada@gmail.com> * Fix formatting issues * Fix formatting issues * Fix formatting issues Refactor forward method signature for better readability. * Fix formatting issues --------- Signed-off-by: Jason Ye <jasonyecanada@gmail.com> Co-authored-by: Kaustubh Tangsali <71059996+ktangsali@users.noreply.github.com>
1 parent a863d0d commit 160d3ad

File tree

11 files changed

+187
-132
lines changed

11 files changed

+187
-132
lines changed

examples/turbulent_channel/2d/custom_aggregator.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,17 @@ class CustomSum(Aggregator):
3030
def __init__(self, params, num_losses, weights=None):
3131
super().__init__(params, num_losses, weights)
3232

33-
def forward(self, losses: Dict[str, torch.Tensor], step: int) -> torch.Tensor:
33+
def forward(
34+
self, losses: Dict[str, torch.Tensor], step: torch.Tensor
35+
) -> torch.Tensor:
3436
"""
3537
Aggregates the losses by summation
3638
3739
Parameters
3840
----------
3941
losses : Dict[str, torch.Tensor]
4042
A dictionary of losses
41-
step : int
43+
step : torch.Tensor
4244
Optimizer step
4345
4446
Returns
@@ -54,8 +56,7 @@ def forward(self, losses: Dict[str, torch.Tensor], step: int) -> torch.Tensor:
5456
loss: torch.Tensor = torch.zeros_like(self.init_loss)
5557

5658
smoothness = 0.0005 # use 0.0005 to smoothen the transition over ~10k steps
57-
step_tensor = torch.tensor(step, dtype=torch.float32)
58-
decay_weight = (torch.tanh((20000 - step_tensor) * smoothness) + 1.0) * 0.5
59+
decay_weight = (torch.tanh((20000 - step.float()) * smoothness) + 1.0) * 0.5
5960

6061
# Add losses
6162
for key in losses.keys():

examples/turbulent_channel/2d_std_wf/custom_aggregator.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,17 @@ class CustomSum(Aggregator):
3030
def __init__(self, params, num_losses, weights=None):
3131
super().__init__(params, num_losses, weights)
3232

33-
def forward(self, losses: Dict[str, torch.Tensor], step: int) -> torch.Tensor:
33+
def forward(
34+
self, losses: Dict[str, torch.Tensor], step: torch.Tensor
35+
) -> torch.Tensor:
3436
"""
3537
Aggregates the losses by summation
3638
3739
Parameters
3840
----------
3941
losses : Dict[str, torch.Tensor]
4042
A dictionary of losses
41-
step : int
43+
step : torch.Tensor
4244
Optimizer step
4345
4446
Returns
@@ -54,8 +56,7 @@ def forward(self, losses: Dict[str, torch.Tensor], step: int) -> torch.Tensor:
5456
loss: torch.Tensor = torch.zeros_like(self.init_loss)
5557

5658
smoothness = 0.0005 # use 0.0005 to smoothen the transition over ~10k steps
57-
step_tensor = torch.tensor(step, dtype=torch.float32)
58-
decay_weight = (torch.tanh((20000 - step_tensor) * smoothness) + 1.0) * 0.5
59+
decay_weight = (torch.tanh((20000 - step.float()) * smoothness) + 1.0) * 0.5
5960

6061
# Add losses
6162
for key in losses.keys():

0 commit comments

Comments
 (0)