File tree Expand file tree Collapse file tree 2 files changed +5
-5
lines changed
include/mlir/Dialect/AMDGPU/IR Expand file tree Collapse file tree 2 files changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -1257,8 +1257,8 @@ def AMDGPU_ScaledWMMAOp
12571257 per-block scaling of inputs, supporting fp4, fp6, and fp8 data formats.
12581258
12591259 The scale instructions support a block size of 16 or 32 and two tile sizes:
1260- - 16x16x128 with mixed f8/f6/f4 formats (output: vector<4xf32>)
1261- - 32x16x128 with f4 format only (output: vector<8xf32>)
1260+ - 16x16x128 with mixed f8/f6/f4 formats (output: vector<8xf32>)
1261+ - 32x16x128 with f4 format only (output: vector<16xf32>)
12621262
12631263 Scale parameters (`scaleA`, `scaleB`) are small vectors of f8 scale values
12641264 (either f8E8M0FNU, or f8E4M3FN) that are packed into i32/i64 values during
Original file line number Diff line number Diff line change @@ -485,7 +485,7 @@ LogicalResult ScaledWMMAOp::verify() {
485485 << bLen;
486486 } else { // m == 32
487487 // For 32×16×128: only fp4 is supported, A is 128, B is 64.
488- if (!isF4 (aElemType))
488+ if (!isF4 (aElemType) && ! isF4 (bElemType) )
489489 return emitOpError (" 32x16x128 only supports fp4 element types" );
490490
491491 if (aLen != 128 )
@@ -513,12 +513,12 @@ LogicalResult ScaledWMMAOp::verify() {
513513 if (isE8M0 (scaleAElemType) && isE8M0 (scaleBElemType))
514514 return success ();
515515
516- // Matrix A (F8|F6) x Matrix B (F4) with Scale A (E8M0), Scale B (E5M2|E4M3).
516+ // Matrix A (F8|F6) x Matrix B (F4) with Scale A (E8M0), Scale B (E5M3|E4M3).
517517 if ((isF8 (aElemType) || isF6 (aElemType)) && isE8M0 (scaleAElemType) &&
518518 isF4 (bElemType) && isE4M3 (scaleBElemType))
519519 return success ();
520520
521- // Matrix A (F4) x Matrix B (F8|F6) with Scale A (E5M2|E4M3), Scale B (E8M0).
521+ // Matrix A (F4) x Matrix B (F8|F6) with Scale A (E5M3|E4M3), Scale B (E8M0).
522522 if (isF4 (aElemType) && isE4M3 (scaleAElemType) &&
523523 (isF8 (bElemType) || isF6 (bElemType)) && isE8M0 (scaleBElemType))
524524 return success ();
You can’t perform that action at this time.
0 commit comments