@@ -417,7 +417,7 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [
417417 DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
418418 DeclareOpInterfaceMethods<DotOpInterface>,
419419 DeclareOpInterfaceMethods<MMAv5OpInterface>,
420- SameVariadicOperandSize
420+ AttrSizedOperandSegments
421421]> {
422422 let summary = "block level op mapping to tensorcore gen5 mma";
423423
@@ -427,29 +427,36 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [
427427 If there is a barrier the result will be safe to read after a barrier wait.
428428 If $two_ctas is set the op will execute a matmul across two contiguous CTAs, it will read the data distributed across the two CTAs.
429429 and synchronize both CTAs if the op is synchronous.
430+
431+ This operation takes and produces an optional token to indicate TMEM read
432+ and write on its accumulator operand. When the tokens are present, they can
433+ be used to check aliasing and modref on the accumulator memory.
430434 }];
431435
432436 let arguments = (ins
433437 TTG_MemDescType:$a,
434438 TTG_MemDescType:$b,
435439 TTG_MemDescType:$d,
440+ Optional<TTG_AsyncToken>:$acc_dep,
436441 I1:$useD,
437442 I1:$pred,
438443 Variadic<TTG_MemDescType>:$barriers,
439444 Variadic<I1>:$barrier_preds,
440445 OptionalAttr<UnitAttr>:$two_ctas
441446 );
447+ let results = (outs Optional<TTG_AsyncToken>:$token);
442448
443449 let builders = [
444- OpBuilder<(ins
445- "Value":$a, "Value":$b, "Value":$d, "Value":$useD, "Value":$pred,
446- CArg<"bool", "false">:$two_ctas, CArg<"ValueRange", "{}">:$barriers,
450+ OpBuilder<(ins "Type":$token,
451+ "Value":$a, "Value":$b, "Value":$d, "Value":$acc_dep, "Value":$useD,
452+ "Value":$pred, CArg<"bool", "false">:$two_ctas,
453+ CArg<"ValueRange", "{}">:$barriers,
447454 CArg<"ValueRange", "{}">:$barrier_preds)>
448455 ];
449456
450457 let assemblyFormat = [{
451- $a`,` $b`,` $d`,` $useD `,` $pred
452- `` custom<BarriersAndPreds>($barriers, $barrier_preds)
458+ $a `,` $b `,` $d `` custom<Token>($acc_dep, type($token)) `,` $useD`,`
459+ $pred `` custom<BarriersAndPreds>($barriers, $barrier_preds)
453460 attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `,`
454461 qualified(type($d)) (`,` qualified(type($barriers))^)?
455462 }];
@@ -459,20 +466,25 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [
459466 DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
460467 DeclareOpInterfaceMethods<DotOpInterface, ["verifyDims", "verifyOutputDims"]>,
461468 DeclareOpInterfaceMethods<MMAv5OpInterface>,
462- SameVariadicOperandSize
469+ AttrSizedOperandSegments
463470]> {
464471 let summary = "block level op mapping to tensorcore gen5 mma";
465472
466473 let description = [{
467474 $d += matrix_multiply(scale($a, $a_scale), scale($b, $b_scale))
468475 If no barrier is given the op is assumed to be synchronous otherwise the op will trigger a commit/arrive on the given barrier.
469476 If there is a barrier the result will be safe to read after a barrier wait.
477+
478+ This operation takes and produces an optional token to indicate TMEM read
479+ and write on its accumulator operand. When the tokens are present, they can
480+ be used to check aliasing and modref on the accumulator memory.
470481 }];
471482
472483 let arguments = (ins
473484 TTG_MemDescType:$a,
474485 TTG_MemDescType:$b,
475486 TTG_MemDescType:$d,
487+ Optional<TTG_AsyncToken>:$acc_dep,
476488 TTG_MemDescType:$a_scale,
477489 TTG_MemDescType:$b_scale,
478490 TT_ScaleDotElemTypeAttr:$a_type,
@@ -482,6 +494,8 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [
482494 Variadic<TTG_MemDescType>:$barriers,
483495 Variadic<I1>:$barrier_preds
484496 );
497+ let results = (outs Optional<TTG_AsyncToken>:$token);
498+
485499 let extraClassDeclaration = [{
486500 int64_t getBlockM();
487501 int64_t getBlockN();
@@ -491,19 +505,19 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [
491505 let builders = [
492506 // Namespaces need to be prefixed so ODS prefers our
493507 // custom builder signature over the default-generated one.
494- OpBuilder<(ins
508+ OpBuilder<(ins "::mlir::Type":$token,
495509 "::mlir::Value":$a, "::mlir::Value":$b, "::mlir::Value":$d,
496- "::mlir::Value":$a_scale , "::mlir::Value":$b_scale ,
497- "::mlir::triton::ScaleDotElemType":$a_type,
510+ "::mlir::Value":$acc_dep , "::mlir::Value":$a_scale ,
511 "::mlir::Value":$b_scale, "::mlir::triton::ScaleDotElemType":$a_type,
498512 "::mlir::triton::ScaleDotElemType":$b_type,
499513 "::mlir::Value":$useD, "::mlir::Value":$pred,
500514 CArg<"::mlir::ValueRange", "{}">:$barriers,
501515 CArg<"::mlir::ValueRange", "{}">:$barrier_preds)>
502516 ];
503517
504518 let assemblyFormat = [{
505- $a `,` $b `,` $d `,` $a_scale `,` $b_scale `,` $useD `,` $pred
506- `lhs` `=` $a_type `rhs` `=` $b_type
519+ $a `,` $b `,` $d `` custom<Token>($acc_dep, type($token)) `,` $a_scale `,`
520+ $b_scale `,` $useD `,` $pred `lhs` `=` $a_type `rhs` `=` $b_type
507521 `` custom<BarriersAndPreds>($barriers, $barrier_preds)
508522 attr-dict `:` qualified(type($a)) `,` qualified(type($b)) `,`
509523 qualified(type($d)) `,` qualified(type($a_scale)) `,`
@@ -517,27 +531,55 @@ def TTNG_TMEMLoadOp : TTNG_Op<"tmem_load"> {
517531 let description = [{
518532 This is similar to ttg.local_load except the result layout is restricted to only a few possibilities.
519533 Therefore we cannot combine this op with any convert layout like local_load.
534+
535+ This operation takes and produces an optional token to indicate TMEM read
536+ on its source operand. When the tokens are present, they can
537+ be used to check aliasing and modref on the TMEM buffer.
538+ }];
539+ let arguments = (ins
540+ Arg<TTG_MemDescType, "", [MemRead<TensorMemory>]>:$src,
541+ Optional<TTG_AsyncToken>:$dep
542+ );
543+ let results = (outs
544+ TT_Tensor:$result,
545+ Optional<TTG_AsyncToken>:$token
546+ );
547+
548+ let assemblyFormat = [{
549+ $src `` custom<Token>($dep, type($token))
550+ attr-dict `:` qualified(type($src)) `->` type($result)
520551 }];
521- let arguments = (ins Arg<TTG_MemDescType, "", [MemRead<TensorMemory>]>:$src);
522552
523- let assemblyFormat = [{$src attr-dict `:` qualified(type($src)) `->` type($result)}];
524- let results = (outs TT_Tensor:$result);
525553 let hasVerifier = 1;
554+
555+ let extraClassDeclaration = [{
556+ RankedTensorType getType() { return getResult().getType(); }
557+ operator TypedValue<RankedTensorType>() { return getResult(); }
558+ }];
526559}
527560
528561def TTNG_TMEMStoreOp : TTNG_Op<"tmem_store"> {
529562 let summary = "Store a distributed tensor into a buffer in tensor memory";
530563
531564 let description = [{
532- This is similar to ttg.local_local except the source layout is restricted to only few possibility.
565 This is similar to ttg.local_store except the source layout is restricted to only a few possibilities.
566+
567+ This operation takes and produces an optional token to indicate TMEM write
568+ on its source operand. When the tokens are present, they can
569+ be used to check aliasing and modref on the TMEM buffer.
533570 }];
534571 let arguments = (ins
535572 Arg<TTG_MemDescType, "", [MemWrite<TensorMemory>]>:$dst,
573+ Optional<TTG_AsyncToken>:$dep,
536574 TT_Tensor:$src,
537575 I1:$pred
538576 );
577+ let results = (outs Optional<TTG_AsyncToken>:$token);
539578
540- let assemblyFormat = [{$src `,` $dst `,` $pred attr-dict `:` type($src) `->` qualified(type($dst))}];
579+ let assemblyFormat = [{
580+ $src `,` $dst `` custom<Token>($dep, type($token)) `,` $pred
581+ attr-dict `:` type($src) `->` qualified(type($dst))
582+ }];
541583 let hasVerifier = 1;
542584}
543585
@@ -551,13 +593,21 @@ def TTNG_TMEMAllocOp : TTNG_Op<"tmem_alloc", [DeclareOpInterfaceMethods<MemoryEf
551593 Explicitly deallocating a buffer is optional; see local_dealloc.
552594 }];
553595 let arguments = (ins Optional<TT_Tensor>:$src);
596+ let results = (outs
597+ TTG_MemDescType:$result,
598+ Optional<TTG_AsyncToken>:$token
599+ );
554600
555601 let assemblyFormat = [{
556602 ($src^)? attr-dict `:` functional-type(operands, results)
557603 }];
558604
559- let results = (outs TTG_MemDescType:$result);
560605 let hasVerifier = 1;
606+
607+ let extraClassDeclaration = [{
608+ triton::gpu::MemDescType getType() { return getResult().getType(); }
609+ operator TypedValue<triton::gpu::MemDescType>() { return getResult(); }
610+ }];
561611}
562612
563613def TTNG_TMEMSubSliceOp : TTNG_Op<"tmem_subslice", [Pure]> {
0 commit comments