Description
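In this case, the output is always all zeros regardless of the input. The max op reduces dim 1 with `keepdim=True`, so afterwards dim 1 has only one element. The mean over that dim is therefore the element itself, `x - x.mean(dim=1, keepdim=True)` is always 0, and GELU(0) = 0.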
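The argument can be written out per row. Here $y = \texttt{self.gemm}(x) \in \mathbb{R}^{B \times N}$, $m_i$ is the row max kept as a $B \times 1$ tensor, and $\Phi$ is the standard normal CDF used by exact GELU. The symbol names are introduced only for this derivation.

$$
\begin{aligned}
m_i &= \max_{1 \le j \le N} y_{ij}, \\
\bar{m}_i &= \frac{1}{1} \sum_{k=1}^{1} m_i = m_i, \\
m_i - \bar{m}_i &= 0, \\
\mathrm{GELU}(0) &= 0 \cdot \Phi(0) = 0.
\end{aligned}
$$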
KernelBench/KernelBench/level2/80_Gemm_Max_Subtract_GELU.py
Lines 13 to 25 in 768d52c
```python
def forward(self, x):
    """
    Args:
        x: Input tensor of shape (batch_size, in_features)
    Returns:
        Output tensor of shape (batch_size, out_features)
    """
    x = self.gemm(x)
    x = torch.max(x, dim=self.max_dim, keepdim=True).values
    x = x - x.mean(dim=1, keepdim=True)
    x = torch.nn.functional.gelu(x)
    return x
```