Skip to content
7 changes: 3 additions & 4 deletions examples/turbulent_channel/2d/custom_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ class CustomSum(Aggregator):
def __init__(self, params, num_losses, weights=None):
super().__init__(params, num_losses, weights)

def forward(self, losses: Dict[str, torch.Tensor], step: int) -> torch.Tensor:
def forward(self, losses: Dict[str, torch.Tensor], step: torch.Tensor) -> torch.Tensor:
"""
Aggregates the losses by summation

Parameters
----------
losses : Dict[str, torch.Tensor]
A dictionary of losses
step : int
step : torch.Tensor
Optimizer step

Returns
Expand All @@ -54,8 +54,7 @@ def forward(self, losses: Dict[str, torch.Tensor], step: int) -> torch.Tensor:
loss: torch.Tensor = torch.zeros_like(self.init_loss)

smoothness = 0.0005 # use 0.0005 to smoothen the transition over ~10k steps
step_tensor = torch.tensor(step, dtype=torch.float32)
decay_weight = (torch.tanh((20000 - step_tensor) * smoothness) + 1.0) * 0.5
decay_weight = (torch.tanh((20000 - step.float()) * smoothness) + 1.0) * 0.5

# Add losses
for key in losses.keys():
Expand Down
7 changes: 3 additions & 4 deletions examples/turbulent_channel/2d_std_wf/custom_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ class CustomSum(Aggregator):
def __init__(self, params, num_losses, weights=None):
super().__init__(params, num_losses, weights)

def forward(self, losses: Dict[str, torch.Tensor], step: int) -> torch.Tensor:
def forward(self, losses: Dict[str, torch.Tensor], step: torch.Tensor) -> torch.Tensor:
"""
Aggregates the losses by summation

Parameters
----------
losses : Dict[str, torch.Tensor]
A dictionary of losses
step : int
step : torch.Tensor
Optimizer step

Returns
Expand All @@ -54,8 +54,7 @@ def forward(self, losses: Dict[str, torch.Tensor], step: int) -> torch.Tensor:
loss: torch.Tensor = torch.zeros_like(self.init_loss)

smoothness = 0.0005 # use 0.0005 to smoothen the transition over ~10k steps
step_tensor = torch.tensor(step, dtype=torch.float32)
decay_weight = (torch.tanh((20000 - step_tensor) * smoothness) + 1.0) * 0.5
decay_weight = (torch.tanh((20000 - step.float()) * smoothness) + 1.0) * 0.5

# Add losses
for key in losses.keys():
Expand Down
Loading