@@ -389,55 +389,102 @@ def TTNG_TMAStoreWaitOp : TTNG_Op<"async_tma_store_wait"> {
389389 let assemblyFormat = "attr-dict";
390390}
391391
392- def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<DotOpInterface>, DeclareOpInterfaceMethods<MMAv5OpInterface>]> {
393- let summary = "block level op mapping to tensorcore gen5 mma";
392+ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [
393+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
394+ DeclareOpInterfaceMethods<DotOpInterface>,
395+ DeclareOpInterfaceMethods<MMAv5OpInterface>,
396+ SameVariadicOperandSize
397+ ]> {
398+ let summary = "block level op mapping to tensorcore gen5 mma";
394399
395- let description = [{
396- $d += matrix_multiply($a, $b).
397- If no barrier is given the op is assumed to be synchronous otherwise the op will trigger a commit/arrive on the given barrier.
398- If there is a barrier the result will be safe to read after a barrier wait.
399- If $two_ctas is set the op will execute a matmul across two contiguous CTAs, it will read the data distributed across the two CTAs.
400- and syncronize both CTAs if the op is synchronous.
401- }];
400+ let description = [{
401+ $d += matrix_multiply($a, $b).
402+ If no barrier is given the op is assumed to be synchronous otherwise the op will trigger a commit/arrive on the given barrier.
403+ If there is a barrier the result will be safe to read after a barrier wait.
404+ If $two_ctas is set the op will execute a matmul across two contiguous CTAs; it will read the data distributed across the two CTAs
405+ and synchronize both CTAs if the op is synchronous.
406+ }];
402407
403- let arguments = (ins TTG_MemDescType:$a,
404- TTG_MemDescType:$b,
405- TTG_MemDescType:$d,
406- I1:$useD,
407- I1:$pred,
408- Optional<TTG_MemDescType>:$barrier,
409- OptionalAttr<UnitAttr>:$two_ctas);
408+ let arguments = (ins
409+ TTG_MemDescType:$a,
410+ TTG_MemDescType:$b,
411+ TTG_MemDescType:$d,
412+ I1:$useD,
413+ I1:$pred,
414+ Variadic<TTG_MemDescType>:$barriers,
415+ Variadic<I1>:$barrier_preds,
416+ OptionalAttr<UnitAttr>:$two_ctas
417+ );
410418
411- // TODO: improve printing format.
412- let assemblyFormat = "$a`,` $b`,` $d`,` $useD`,` $pred (`,` $barrier^)? attr-dict `:` functional-type(operands, results)";
419+ let builders = [
420+ OpBuilder<(ins
421+ "Value":$a, "Value":$b, "Value":$d, "Value":$useD, "Value":$pred,
422+ CArg<"bool", "false">:$two_ctas, CArg<"ValueRange", "{}">:$barriers,
423+ CArg<"ValueRange", "{}">:$barrier_preds)>
424+ ];
425+
426+ let assemblyFormat = [{
427+ $a`,` $b`,` $d`,` $useD`,` $pred
428+ `` custom<BarriersAndPreds>($barriers, $barrier_preds)
429+ attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `,`
430+ qualified(type($d)) (`,` qualified(type($barriers))^)?
431+ }];
413432}
414433
415- def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<DotOpInterface, ["verifyDims", "verifyOutputDims"]>, DeclareOpInterfaceMethods<MMAv5OpInterface>]> {
416- let summary = "block level op mapping to tensorcore gen5 mma";
434+ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [
435+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
436+ DeclareOpInterfaceMethods<DotOpInterface, ["verifyDims", "verifyOutputDims"]>,
437+ DeclareOpInterfaceMethods<MMAv5OpInterface>,
438+ SameVariadicOperandSize
439+ ]> {
440+ let summary = "block level op mapping to tensorcore gen5 mma";
417441
418- let description = [{
419- $d += matrix_multiply(scale($lhs, $lhs_scale), scale(rlhs, $rhs_scale))
420- If no barrier is given the op is assumed to be synchronous otherwise the op will trigger a commit/arrive on the given barrier.
421- If there is a barrier the result will be safe to read after a barrier wait.
422- }];
442+ let description = [{
443+ $d += matrix_multiply(scale($a, $a_scale), scale($b, $b_scale))
444+ If no barrier is given the op is assumed to be synchronous otherwise the op will trigger a commit/arrive on the given barrier.
445+ If there is a barrier the result will be safe to read after a barrier wait.
446+ }];
423447
424- let arguments = (ins TTG_MemDescType:$a,
425- TTG_MemDescType:$b,
426- TTG_MemDescType:$d,
427- TTG_MemDescType:$a_scale,
428- TTG_MemDescType:$b_scale,
429- TT_ScaleDotElemTypeAttr:$a_type,
430- TT_ScaleDotElemTypeAttr:$b_type,
431- I1:$useD,
432- I1:$pred,
433- Optional<TTG_MemDescType>:$barrier);
434- let extraClassDeclaration = [{
435- int64_t getBlockM();
436- int64_t getBlockN();
437- int64_t getBlockK();
438- }];
439- // TODO: improve printing format.
440- let assemblyFormat = "$a `,` $b `,` $d `,` $a_scale `,` $b_scale `,` $useD`,` $pred `lhs` `=` $a_type `rhs` `=` $b_type (`,` $barrier^)? attr-dict `:` functional-type(operands, results)";
448+ let arguments = (ins
449+ TTG_MemDescType:$a,
450+ TTG_MemDescType:$b,
451+ TTG_MemDescType:$d,
452+ TTG_MemDescType:$a_scale,
453+ TTG_MemDescType:$b_scale,
454+ TT_ScaleDotElemTypeAttr:$a_type,
455+ TT_ScaleDotElemTypeAttr:$b_type,
456+ I1:$useD,
457+ I1:$pred,
458+ Variadic<TTG_MemDescType>:$barriers,
459+ Variadic<I1>:$barrier_preds
460+ );
461+ let extraClassDeclaration = [{
462+ int64_t getBlockM();
463+ int64_t getBlockN();
464+ int64_t getBlockK();
465+ }];
466+
467+ let builders = [
468+ // Namespaces need to be prefixed so ODS prefers our
469+ // custom builder signature over the default-generated one.
470+ OpBuilder<(ins
471+ "::mlir::Value":$a, "::mlir::Value":$b, "::mlir::Value":$d,
472+ "::mlir::Value":$a_scale, "::mlir::Value":$b_scale,
473+ "::mlir::triton::ScaleDotElemType":$a_type,
474+ "::mlir::triton::ScaleDotElemType":$b_type,
475+ "::mlir::Value":$useD, "::mlir::Value":$pred,
476+ CArg<"::mlir::ValueRange", "{}">:$barriers,
477+ CArg<"::mlir::ValueRange", "{}">:$barrier_preds)>
478+ ];
479+
480+ let assemblyFormat = [{
481+ $a `,` $b `,` $d `,` $a_scale `,` $b_scale `,` $useD`,` $pred
482+ `lhs` `=` $a_type `rhs` `=` $b_type
483+ `` custom<BarriersAndPreds>($barriers, $barrier_preds)
484+ attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `,`
485+ qualified(type($d)) `,` qualified(type($a_scale)) `,`
486+ qualified(type($b_scale)) (`,` qualified(type($barriers))^)?
487+ }];
441488}
442489
443490def TTNG_TMEMLoadOp : TTNG_Op<"tmem_load"> {
0 commit comments