
Commit 0714ac4

rmsnorm bwd: make blocking math more robust
ghstack-source-id: 1b9366a
Pull-Request: #532
1 parent 496c120 commit 0714ac4

File tree

1 file changed (+2, -2 lines)


tritonbench/operators/rms_norm/fused_triton.py

Lines changed: 2 additions & 2 deletions
@@ -106,8 +106,8 @@ def backward(ctx, dy):
 
         M, N = x_arg.shape
         NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
-        BLOCK_SIZE_M = min(2048, triton.next_power_of_2(M // (8 * NUM_SMS)))
-        PARTIAL_SIZE = math.ceil(M / BLOCK_SIZE_M)
+        BLOCK_SIZE_M = min(2048, triton.next_power_of_2(triton.cdiv(M, (8 * NUM_SMS))))
+        PARTIAL_SIZE = triton.cdiv(M, BLOCK_SIZE_M)
 
         # Columnwise stride for reducing partial sums at end, contiguous loads
         _dw = torch.empty((PARTIAL_SIZE, N), dtype=w.dtype, device=w.device)
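Why the new math is more robust (a minimal sketch, not part of the commit): when M is smaller than 8 * NUM_SMS, the old floor division M // (8 * NUM_SMS) evaluates to 0, and Triton's bit-twiddling next_power_of_2 maps 0 back to 0, so BLOCK_SIZE_M becomes min(2048, 0) == 0 and math.ceil(M / BLOCK_SIZE_M) raises ZeroDivisionError. triton.cdiv rounds up, so the inner quotient is at least 1; using it for PARTIAL_SIZE also keeps the arithmetic in integers. The stand-in helpers and the NUM_SMS value below are illustrative assumptions, not tritonbench code.

import math

def cdiv(a, b):
    # Stand-in for triton.cdiv: ceiling division in pure integers.
    return (a + b - 1) // b

def next_power_of_2(n):
    # Stand-in for triton.next_power_of_2; like Triton's bit-twiddling
    # version, it maps 0 to 0, which is exactly the hazard here.
    return 0 if n == 0 else 1 << (n - 1).bit_length()

M, NUM_SMS = 100, 132  # small batch, H100-class SM count: M < 8 * NUM_SMS

# Old blocking math: floor division collapses to 0 for small M.
block_old = min(2048, next_power_of_2(M // (8 * NUM_SMS)))  # == 0
try:
    partial_old = math.ceil(M / block_old)
except ZeroDivisionError:
    print("old math fails: division by zero")

# New blocking math: cdiv rounds up, so the block size is at least 1.
block_new = min(2048, next_power_of_2(cdiv(M, 8 * NUM_SMS)))  # == 1
partial_new = cdiv(M, block_new)                              # == 100
print("new math:", block_new, partial_new)

For large M the behavior is unchanged: the quotient is clamped at 2048 either way, and triton.cdiv gives the same partial count as math.ceil while avoiding float division.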

0 commit comments