Skip to content

Commit e330cc1

Browse files
committed
Minor copilot comments
1 parent 0291d87 commit e330cc1

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,8 +1257,8 @@ def AMDGPU_ScaledWMMAOp
12571257
per-block scaling of inputs, supporting fp4, fp6, and fp8 data formats.
12581258

12591259
The scale instructions support a block size of 16 or 32 and two tile sizes:
1260-
- 16x16x128 with mixed f8/f6/f4 formats (output: vector<4xf32>)
1261-
- 32x16x128 with f4 format only (output: vector<8xf32>)
1260+
- 16x16x128 with mixed f8/f6/f4 formats (output: vector<8xf32>)
1261+
- 32x16x128 with f4 format only (output: vector<16xf32>)
12621262

12631263
Scale parameters (`scaleA`, `scaleB`) are small vectors of f8 scale values
12641264
(either f8E8M0FNU, or f8E4M3FN) that are packed into i32/i64 values during

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ LogicalResult ScaledWMMAOp::verify() {
485485
<< bLen;
486486
} else { // m == 32
487487
// For 32×16×128: only fp4 is supported, A is 128, B is 64.
488-
if (!isF4(aElemType))
488+
if (!isF4(aElemType) && !isF4(bElemType))
489489
return emitOpError("32x16x128 only supports fp4 element types");
490490

491491
if (aLen != 128)
@@ -513,12 +513,12 @@ LogicalResult ScaledWMMAOp::verify() {
513513
if (isE8M0(scaleAElemType) && isE8M0(scaleBElemType))
514514
return success();
515515

516-
// Matrix A (F8|F6) x Matrix B (F4) with Scale A (E8M0), Scale B (E5M2|E4M3).
516+
// Matrix A (F8|F6) x Matrix B (F4) with Scale A (E8M0), Scale B (E5M3|E4M3).
517517
if ((isF8(aElemType) || isF6(aElemType)) && isE8M0(scaleAElemType) &&
518518
isF4(bElemType) && isE4M3(scaleBElemType))
519519
return success();
520520

521-
// Matrix A (F4) x Matrix B (F8|F6) with Scale A (E5M2|E4M3), Scale B (E8M0).
521+
// Matrix A (F4) x Matrix B (F8|F6) with Scale A (E5M3|E4M3), Scale B (E8M0).
522522
if (isF4(aElemType) && isE4M3(scaleAElemType) &&
523523
(isF8(bElemType) || isF6(bElemType)) && isE8M0(scaleBElemType))
524524
return success();

0 commit comments

Comments
 (0)