|
312 | 312 | allow_tf32=ALLOW_TF32, |
313 | 313 | ) |
314 | 314 |
|
315 | | - {% if ki == k_tiles - 1 %} |
316 | | - # rematerialize rm and rn to save registers |
317 | | - rcm = rm + tl.arange(0, BLOCK_M) |
318 | | - rcn = rn + tl.arange(0, BLOCK_N) |
319 | | - idx_m = rcm[:, None] |
320 | | - idx_n = rcn[None, :] |
321 | | - mask = (idx_m < M) & (idx_n < N) |
322 | | -
|
323 | | - # inductor generates a suffix |
324 | | - {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12)}} |
325 | | - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) |
326 | | - {% endif %} |
| 315 | + if ki == k_tiles - 1: |
| 316 | + # rematerialize rm and rn to save registers |
| 317 | + rcm = rm + tl.arange(0, BLOCK_M) |
| 318 | + rcn = rn + tl.arange(0, BLOCK_N) |
| 319 | + idx_m = rcm[:, None] |
| 320 | + idx_n = rcn[None, :] |
| 321 | + mask = (idx_m < M) & (idx_n < N) |
| 322 | +
|
| 323 | + # inductor generates a suffix |
| 324 | + {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12)}} |
| 325 | + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) |
| 326 | +
|
327 | 327 | """, |
328 | 328 | ) |
329 | 329 |
|
@@ -467,31 +467,30 @@ def apply_scaling( |
467 | 467 | else: |
468 | 468 | accumulator += tl.dot(a, b.T) |
469 | 469 |
|
470 | | - {% if ki == k_tiles - 1 %} |
471 | | - # Apply inverse scaling |
472 | | - offs_cm = offs_am + tl.arange(0, BLOCK_M) |
473 | | - offs_cn = offs_bn + tl.arange(0, BLOCK_N) |
474 | | - # Apply scaling |
475 | | - accumulator = apply_scaling( |
476 | | - accumulator, |
477 | | - a_scale, |
478 | | - b_scale, |
479 | | - SCALING_ROWWISE, |
480 | | - offs_cm, |
481 | | - offs_cn, |
482 | | - M, |
483 | | - N, |
484 | | - stride_a_scale_m, |
485 | | - stride_b_scale_n, |
486 | | - ) |
| 470 | + if ki == k_tiles - 1: |
| 471 | + # Apply inverse scaling |
| 472 | + offs_cm = offs_am + tl.arange(0, BLOCK_M) |
| 473 | + offs_cn = offs_bn + tl.arange(0, BLOCK_N) |
| 474 | + # Apply scaling |
| 475 | + accumulator = apply_scaling( |
| 476 | + accumulator, |
| 477 | + a_scale, |
| 478 | + b_scale, |
| 479 | + SCALING_ROWWISE, |
| 480 | + offs_cm, |
| 481 | + offs_cn, |
| 482 | + M, |
| 483 | + N, |
| 484 | + stride_a_scale_m, |
| 485 | + stride_b_scale_n, |
| 486 | + ) |
487 | 487 |
|
488 | | - idx_m = offs_cm[:, None] |
489 | | - idx_n = offs_cn[None, :] |
490 | | - mask = (idx_m < M) & (idx_n < N) |
491 | | - # inductor generates a suffix |
492 | | - {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}} |
493 | | - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) |
494 | | - {% endif %} |
| 488 | + idx_m = offs_cm[:, None] |
| 489 | + idx_n = offs_cn[None, :] |
| 490 | + mask = (idx_m < M) & (idx_n < N) |
| 491 | + # inductor generates a suffix |
| 492 | + {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}} |
| 493 | + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) |
495 | 494 | """ |
496 | 495 |
|
497 | 496 |
|
|
0 commit comments