@@ -301,7 +301,6 @@ def TT_LoadOp : TT_Op<"load", [
301301def TT_StoreOp : TT_Op<"store", [
302302 SameLoadStoreOperandsShape,
303303 SameLoadStoreOperandsEncoding,
304- MemoryEffects<[MemWrite<GlobalMemory>]>,
305304 TypesMatchWith<"value type matches ptr type", "ptr", "value",
306305 "getPointeeType($_self)">,
307306 TypesMatchWith<"mask type matches ptr type", "ptr", "mask",
@@ -312,7 +311,7 @@ def TT_StoreOp : TT_Op<"store", [
312311
313312 let arguments = (
314313 ins
315- AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr,
314+ Arg< AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>, "", [MemWrite<GlobalMemory> ]>:$ptr,
316315 TT_Type:$value,
317316 Optional<TT_BoolLike>:$mask,
318317 DefaultValuedAttr<DenseI32ArrayAttr, "::llvm::ArrayRef<int32_t>{}">:$boundaryCheck,
@@ -352,8 +351,6 @@ def TT_StoreOp : TT_Op<"store", [
352351def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [
353352 SameOperandsAndResultShape,
354353 SameOperandsAndResultEncoding,
355- MemoryEffects<[MemRead<GlobalMemory>]>,
356- MemoryEffects<[MemWrite<GlobalMemory>]>,
357354 TypesMatchWith<"ptr type matches value type", "val", "ptr",
358355 "getPointerTypeSameShape($_self)">,
359356 TypesMatchWith<"mask type matches value type",
@@ -368,9 +365,14 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [
368365 return old value at $ptr
369366 }];
370367
371- let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrLike:$ptr,
372- TT_Type:$val, Optional<TT_BoolLike>:$mask,
373- TT_MemSemanticAttr:$sem, TT_MemSyncScopeAttr:$scope);
368+ let arguments = (ins
369+ TT_AtomicRMWAttr:$atomic_rmw_op,
370+ Arg<TT_PtrLike, "", [MemWrite<GlobalMemory>, MemRead<GlobalMemory>]>:$ptr,
371+ TT_Type:$val,
372+ Optional<TT_BoolLike>:$mask,
373+ TT_MemSemanticAttr:$sem,
374+ TT_MemSyncScopeAttr:$scope
375+ );
374376
375377 let results = (outs TT_Type:$result);
376378
@@ -382,9 +384,7 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [
382384 }];
383385}
384386
385- def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead<GlobalMemory>]>,
386- MemoryEffects<[MemWrite<GlobalMemory>]>,
387- SameOperandsAndResultShape,
387+ def TT_AtomicCASOp : TT_Op<"atomic_cas", [SameOperandsAndResultShape,
388388 SameOperandsAndResultEncoding]> {
389389 let summary = "atomic cas";
390390
@@ -398,8 +398,12 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead<GlobalMemory>]>
398398 return $old
399399 }];
400400
401- let arguments = (ins TT_PtrLike:$ptr, TT_Type:$cmp, TT_Type:$val,
402- TT_MemSemanticAttr:$sem, TT_MemSyncScopeAttr:$scope);
401+ let arguments = (ins
402+ Arg<TT_PtrLike, "", [MemWrite<GlobalMemory>, MemRead<GlobalMemory>]>:$ptr,
403+ TT_Type:$cmp,
404+ TT_Type:$val,
405+ TT_MemSemanticAttr:$sem,
406+ TT_MemSyncScopeAttr:$scope);
403407
404408 let results = (outs TT_Type:$result);
405409
@@ -489,8 +493,7 @@ def TT_BroadcastOp : TT_Op<"broadcast", [Pure,
489493 let hasVerifier = 1;
490494}
491495
492- // cat is not `pure` because it may reorder elements
493- def TT_CatOp : TT_Op<"cat", [NoMemoryEffect,
496+ def TT_CatOp : TT_Op<"cat", [Pure,
494497 SameTypeOperands,
495498 SameOperandsAndResultElementType]> {
496499 let summary = "concatenate 2 tensors";
@@ -502,8 +505,7 @@ def TT_CatOp : TT_Op<"cat", [NoMemoryEffect,
502505 let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)";
503506}
504507
505- def TT_JoinOp : TT_Op<"join", [
506- NoMemoryEffect, SameTypeOperands]> {
508+ def TT_JoinOp : TT_Op<"join", [Pure, SameTypeOperands]> {
507509 let summary = "join two tensors along a new, minor dimension";
508510 let description = [{
509511 For example, if the two input tensors are 4x8xf32, returns a tensor of
@@ -523,7 +525,7 @@ def TT_JoinOp : TT_Op<"join", [
523525}
524526
525527def TT_SplitOp : TT_Op<"split", [
526- NoMemoryEffect ,
528+ Pure ,
527529 InferTypeOpWithLayoutEquivalence,
528530 TypesMatchWith<"outLHS and outRHS types match",
529531 "outLHS", "outRHS", "$_self">,
@@ -1256,18 +1258,15 @@ def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable
12561258}
12571259
12581260
1259- def TT_DescriptorLoadOp : TT_Op<"descriptor_load", [MemoryEffects<[MemRead<GlobalMemory>]>] > {
1261+ def TT_DescriptorLoadOp : TT_Op<"descriptor_load"> {
12601262 let summary = "Load from descriptor";
12611263 let description = [{
12621264 This operation will be lowered to Nvidia TMA load operation on targets supporting it.
12631265 `desc` is a tensor descriptor object.
12641266 The destination tensor type and shape must match the descriptor otherwise the result is undefined.
1265-
1266- This is an escape hatch and is only there for testing/experimenting.
1267- This op will be removed in the future.
12681267 }];
12691268 let arguments = (ins
1270- TT_TensorDescType:$desc,
1269+ Arg< TT_TensorDescType, "", [MemRead<GlobalMemory>]> :$desc,
12711270 Variadic<I32>:$indices,
12721271 DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
12731272 DefaultValuedAttr<TT_EvictionPolicyAttr, "::mlir::triton::EvictionPolicy::NORMAL">:$evict
@@ -1287,20 +1286,15 @@ def TT_DescriptorLoadOp : TT_Op<"descriptor_load", [MemoryEffects<[MemRead<Globa
12871286 let hasVerifier = 1;
12881287}
12891288
1290- def TT_DescriptorStoreOp : TT_Op<"descriptor_store", [
1291- MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>,
1292- ]> {
1289+ def TT_DescriptorStoreOp : TT_Op<"descriptor_store"> {
12931290 let summary = "store value based on descriptor";
12941291 let description = [{
12951292 This operation will be lowered to Nvidia TMA store operation on targets supporting it.
12961293 `desc` is a tensor descriptor object.
12971294 The shape and types of `src` must match the descriptor otherwise the result is undefined.
1298-
1299- This is an escape hatch and is only there for testing/experimenting.
1300- This op will be removed in the future.
13011295 }];
13021296 let arguments = (ins
1303- TT_TensorDescType:$desc,
1297+ Arg< TT_TensorDescType, "", [MemRead<GlobalMemory>, MemWrite<GlobalMemory>]> :$desc,
13041298 TT_Tensor:$src,
13051299 Variadic<I32>:$indices
13061300 );
@@ -1313,22 +1307,19 @@ def TT_DescriptorStoreOp : TT_Op<"descriptor_store", [
13131307 let hasVerifier = 1;
13141308}
13151309
1316- def TT_DescriptorGatherOp : TT_Op<"descriptor_gather", [MemoryEffects<[MemRead<GlobalMemory>]>] > {
1310+ def TT_DescriptorGatherOp : TT_Op<"descriptor_gather"> {
13171311 let summary = "gather multiple rows from a descriptor into a single tensor";
13181312 let description = [{
13191313 The `tt.descriptor_gather` op will be lowered to NVIDIA TMA
1320- load operations on targets that support it.
1314+ gather operations on targets that support it.
13211315
13221316 `desc_ptr` is a pointer to the TMA descriptor allocated in global memory.
13231317 The descriptor block must have 1 row and the indices must be a 1D tensor.
13241318 Accordingly, the result is a 2D tensor multiple rows.
1325-
1326- This is an escape hatch and is only there for testing/experimenting. This
1327- op will be removed in the future.
13281319 }];
13291320
13301321 let arguments = (ins
1331- TT_TensorDescType:$desc,
1322+ Arg< TT_TensorDescType, "", [MemRead<GlobalMemory>]> :$desc,
13321323 RankedTensorOf<[I32]>:$x_offsets,
13331324 I32:$y_offset
13341325 );
@@ -1349,11 +1340,19 @@ def TT_DescriptorGatherOp : TT_Op<"descriptor_gather", [MemoryEffects<[MemRead<G
13491340 }];
13501341}
13511342
1352- def TT_DescriptorScatterOp : TT_Op<"descriptor_scatter", [
1353- MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>,
1354- ]> {
1343+ def TT_DescriptorScatterOp : TT_Op<"descriptor_scatter"> {
1344+ let summary = "scatter multiple rows from a tensor into a descriptor";
1345+ let description = [{
1346+ The `tt.descriptor_scatter` op will be lowered to NVIDIA TMA
1347+ scatter operations on targets that support it.
1348+
1349+ `desc_ptr` is a pointer to the TMA descriptor allocated in global memory.
1350+ The descriptor block must have 1 row and the indices must be a 1D tensor.
1351+ Accordingly, the result is a 2D tensor multiple rows.
1352+ }];
1353+
13551354 let arguments = (ins
1356- TT_TensorDescType:$desc,
1355+ Arg< TT_TensorDescType, "", [MemRead<GlobalMemory>, MemWrite<GlobalMemory>]> :$desc,
13571356 RankedTensorOf<[I32]>:$x_offsets,
13581357 I32:$y_offset,
13591358 TT_Tensor:$src
0 commit comments