diff --git a/tritonbench/operators/rms_norm/fused_triton.py b/tritonbench/operators/rms_norm/fused_triton.py index 5a8f1472..5ea3fb3a 100644 --- a/tritonbench/operators/rms_norm/fused_triton.py +++ b/tritonbench/operators/rms_norm/fused_triton.py @@ -106,8 +106,8 @@ def backward(ctx, dy): M, N = x_arg.shape NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - BLOCK_SIZE_M = min(2048, triton.next_power_of_2(M // (8 * NUM_SMS))) - PARTIAL_SIZE = math.ceil(M / BLOCK_SIZE_M) + BLOCK_SIZE_M = min(2048, triton.next_power_of_2(triton.cdiv(M, (8 * NUM_SMS)))) + PARTIAL_SIZE = triton.cdiv(M, BLOCK_SIZE_M) # Columnwise stride for reducing partial sums at end, contiguous loads _dw = torch.empty((PARTIAL_SIZE, N), dtype=w.dtype, device=w.device)