Skip to content

Commit f60465e

Browse files
authored
[Dialect] Cleanup and improve granularity of side effects (triton-lang#6476)
This PR cleans up side effect specification to use the `Arg<Type, desc, [effects...]>` trick in ODS, which is a more concise way to specify memory effects. At the same time, many ops were specifying coarser side effects than necessary, and this PR changes them to target a specific SSA operand or result where applicable. Most notably, the many ops that access TensorMemory or GlobalMemory were made to have finer-grained side effects.
1 parent dfcdc27 commit f60465e

File tree

6 files changed

+212
-310
lines changed

6 files changed

+212
-310
lines changed

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 38 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,6 @@ def TT_LoadOp : TT_Op<"load", [
301301
def TT_StoreOp : TT_Op<"store", [
302302
SameLoadStoreOperandsShape,
303303
SameLoadStoreOperandsEncoding,
304-
MemoryEffects<[MemWrite<GlobalMemory>]>,
305304
TypesMatchWith<"value type matches ptr type", "ptr", "value",
306305
"getPointeeType($_self)">,
307306
TypesMatchWith<"mask type matches ptr type", "ptr", "mask",
@@ -312,7 +311,7 @@ def TT_StoreOp : TT_Op<"store", [
312311

313312
let arguments = (
314313
ins
315-
AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>:$ptr,
314+
Arg<AnyTypeOf<[TT_PtrLike, TT_TensorPtr]>, "", [MemWrite<GlobalMemory>]>:$ptr,
316315
TT_Type:$value,
317316
Optional<TT_BoolLike>:$mask,
318317
DefaultValuedAttr<DenseI32ArrayAttr, "::llvm::ArrayRef<int32_t>{}">:$boundaryCheck,
@@ -352,8 +351,6 @@ def TT_StoreOp : TT_Op<"store", [
352351
def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [
353352
SameOperandsAndResultShape,
354353
SameOperandsAndResultEncoding,
355-
MemoryEffects<[MemRead<GlobalMemory>]>,
356-
MemoryEffects<[MemWrite<GlobalMemory>]>,
357354
TypesMatchWith<"ptr type matches value type", "val", "ptr",
358355
"getPointerTypeSameShape($_self)">,
359356
TypesMatchWith<"mask type matches value type",
@@ -368,9 +365,14 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [
368365
return old value at $ptr
369366
}];
370367

371-
let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, TT_PtrLike:$ptr,
372-
TT_Type:$val, Optional<TT_BoolLike>:$mask,
373-
TT_MemSemanticAttr:$sem, TT_MemSyncScopeAttr:$scope);
368+
let arguments = (ins
369+
TT_AtomicRMWAttr:$atomic_rmw_op,
370+
Arg<TT_PtrLike, "", [MemWrite<GlobalMemory>, MemRead<GlobalMemory>]>:$ptr,
371+
TT_Type:$val,
372+
Optional<TT_BoolLike>:$mask,
373+
TT_MemSemanticAttr:$sem,
374+
TT_MemSyncScopeAttr:$scope
375+
);
374376

375377
let results = (outs TT_Type:$result);
376378

@@ -382,9 +384,7 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [
382384
}];
383385
}
384386

385-
def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead<GlobalMemory>]>,
386-
MemoryEffects<[MemWrite<GlobalMemory>]>,
387-
SameOperandsAndResultShape,
387+
def TT_AtomicCASOp : TT_Op<"atomic_cas", [SameOperandsAndResultShape,
388388
SameOperandsAndResultEncoding]> {
389389
let summary = "atomic cas";
390390

@@ -398,8 +398,12 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas", [MemoryEffects<[MemRead<GlobalMemory>]>
398398
return $old
399399
}];
400400

401-
let arguments = (ins TT_PtrLike:$ptr, TT_Type:$cmp, TT_Type:$val,
402-
TT_MemSemanticAttr:$sem, TT_MemSyncScopeAttr:$scope);
401+
let arguments = (ins
402+
Arg<TT_PtrLike, "", [MemWrite<GlobalMemory>, MemRead<GlobalMemory>]>:$ptr,
403+
TT_Type:$cmp,
404+
TT_Type:$val,
405+
TT_MemSemanticAttr:$sem,
406+
TT_MemSyncScopeAttr:$scope);
403407

404408
let results = (outs TT_Type:$result);
405409

@@ -489,8 +493,7 @@ def TT_BroadcastOp : TT_Op<"broadcast", [Pure,
489493
let hasVerifier = 1;
490494
}
491495

492-
// cat is not `pure` because it may reorder elements
493-
def TT_CatOp : TT_Op<"cat", [NoMemoryEffect,
496+
def TT_CatOp : TT_Op<"cat", [Pure,
494497
SameTypeOperands,
495498
SameOperandsAndResultElementType]> {
496499
let summary = "concatenate 2 tensors";
@@ -502,8 +505,7 @@ def TT_CatOp : TT_Op<"cat", [NoMemoryEffect,
502505
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)";
503506
}
504507

505-
def TT_JoinOp : TT_Op<"join", [
506-
NoMemoryEffect, SameTypeOperands]> {
508+
def TT_JoinOp : TT_Op<"join", [Pure, SameTypeOperands]> {
507509
let summary = "join two tensors along a new, minor dimension";
508510
let description = [{
509511
For example, if the two input tensors are 4x8xf32, returns a tensor of
@@ -523,7 +525,7 @@ def TT_JoinOp : TT_Op<"join", [
523525
}
524526

525527
def TT_SplitOp : TT_Op<"split", [
526-
NoMemoryEffect,
528+
Pure,
527529
InferTypeOpWithLayoutEquivalence,
528530
TypesMatchWith<"outLHS and outRHS types match",
529531
"outLHS", "outRHS", "$_self">,
@@ -1256,18 +1258,15 @@ def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable
12561258
}
12571259

12581260

1259-
def TT_DescriptorLoadOp : TT_Op<"descriptor_load", [MemoryEffects<[MemRead<GlobalMemory>]>]> {
1261+
def TT_DescriptorLoadOp : TT_Op<"descriptor_load"> {
12601262
let summary = "Load from descriptor";
12611263
let description = [{
12621264
This operation will be lowered to Nvidia TMA load operation on targets supporting it.
12631265
`desc` is a tensor descriptor object.
12641266
The destination tensor type and shape must match the descriptor otherwise the result is undefined.
1265-
1266-
This is an escape hatch and is only there for testing/experimenting.
1267-
This op will be removed in the future.
12681267
}];
12691268
let arguments = (ins
1270-
TT_TensorDescType:$desc,
1269+
Arg<TT_TensorDescType, "", [MemRead<GlobalMemory>]>:$desc,
12711270
Variadic<I32>:$indices,
12721271
DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
12731272
DefaultValuedAttr<TT_EvictionPolicyAttr, "::mlir::triton::EvictionPolicy::NORMAL">:$evict
@@ -1287,20 +1286,15 @@ def TT_DescriptorLoadOp : TT_Op<"descriptor_load", [MemoryEffects<[MemRead<Globa
12871286
let hasVerifier = 1;
12881287
}
12891288

1290-
def TT_DescriptorStoreOp : TT_Op<"descriptor_store", [
1291-
MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>,
1292-
]> {
1289+
def TT_DescriptorStoreOp : TT_Op<"descriptor_store"> {
12931290
let summary = "store value based on descriptor";
12941291
let description = [{
12951292
This operation will be lowered to Nvidia TMA store operation on targets supporting it.
12961293
`desc` is a tensor descriptor object.
12971294
The shape and types of `src` must match the descriptor otherwise the result is undefined.
1298-
1299-
This is an escape hatch and is only there for testing/experimenting.
1300-
This op will be removed in the future.
13011295
}];
13021296
let arguments = (ins
1303-
TT_TensorDescType:$desc,
1297+
Arg<TT_TensorDescType, "", [MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>:$desc,
13041298
TT_Tensor:$src,
13051299
Variadic<I32>:$indices
13061300
);
@@ -1313,22 +1307,19 @@ def TT_DescriptorStoreOp : TT_Op<"descriptor_store", [
13131307
let hasVerifier = 1;
13141308
}
13151309

1316-
def TT_DescriptorGatherOp : TT_Op<"descriptor_gather", [MemoryEffects<[MemRead<GlobalMemory>]>]> {
1310+
def TT_DescriptorGatherOp : TT_Op<"descriptor_gather"> {
13171311
let summary = "gather multiple rows from a descriptor into a single tensor";
13181312
let description = [{
13191313
The `tt.descriptor_gather` op will be lowered to NVIDIA TMA
1320-
load operations on targets that support it.
1314+
gather operations on targets that support it.
13211315

13221316
`desc_ptr` is a pointer to the TMA descriptor allocated in global memory.
13231317
The descriptor block must have 1 row and the indices must be a 1D tensor.
13241318
Accordingly, the result is a 2D tensor with multiple rows.
1325-
1326-
This is an escape hatch and is only there for testing/experimenting. This
1327-
op will be removed in the future.
13281319
}];
13291320

13301321
let arguments = (ins
1331-
TT_TensorDescType:$desc,
1322+
Arg<TT_TensorDescType, "", [MemRead<GlobalMemory>]>:$desc,
13321323
RankedTensorOf<[I32]>:$x_offsets,
13331324
I32:$y_offset
13341325
);
@@ -1349,11 +1340,19 @@ def TT_DescriptorGatherOp : TT_Op<"descriptor_gather", [MemoryEffects<[MemRead<G
13491340
}];
13501341
}
13511342

1352-
def TT_DescriptorScatterOp : TT_Op<"descriptor_scatter", [
1353-
MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>,
1354-
]> {
1343+
def TT_DescriptorScatterOp : TT_Op<"descriptor_scatter"> {
1344+
let summary = "scatter multiple rows from a tensor into a descriptor";
1345+
let description = [{
1346+
The `tt.descriptor_scatter` op will be lowered to NVIDIA TMA
1347+
scatter operations on targets that support it.
1348+
1349+
`desc_ptr` is a pointer to the TMA descriptor allocated in global memory.
1350+
The descriptor block must have 1 row and the indices must be a 1D tensor.
1351+
Accordingly, the result is a 2D tensor with multiple rows.
1352+
}];
1353+
13551354
let arguments = (ins
1356-
TT_TensorDescType:$desc,
1355+
Arg<TT_TensorDescType, "", [MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>:$desc,
13571356
RankedTensorOf<[I32]>:$x_offsets,
13581357
I32:$y_offset,
13591358
TT_Tensor:$src

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> {
7777

7878
def TTG_AsyncCopyGlobalToLocalOp : TTG_Op<"async_copy_global_to_local", [
7979
AttrSizedOperandSegments,
80-
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
8180
TypesMatchWith<"infer mask type from src type",
8281
"src", "mask", "getI1SameShape($_self)",
8382
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
@@ -95,9 +94,9 @@ def TTG_AsyncCopyGlobalToLocalOp : TTG_Op<"async_copy_global_to_local", [
9594
operands are the same as tt.load.
9695
}];
9796

98-
let arguments = (
99-
ins TT_PtrTensor:$src,
100-
TTG_MemDescType:$result,
97+
let arguments = (ins
98+
Arg<TT_PtrTensor, "", [MemRead<GlobalMemory>]>:$src,
99+
Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$result,
101100
Optional<I1Tensor>:$mask,
102101
Optional<TT_Type>:$other,
103102
DefaultValuedAttr<TT_CacheModifierAttr, "triton::CacheModifier::NONE">:$cache,
@@ -146,11 +145,11 @@ def TTG_LocalAllocOp : TTG_Op<"local_alloc", [DeclareOpInterfaceMethods<MemoryEf
146145

147146
Explicitly deallocating a buffer is optional; see local_dealloc.
148147
}];
149-
let arguments = (
150-
ins
148+
let arguments = (ins
151149
Optional<TT_Tensor>:$src,
152150
OptionalAttr<I32Attr>:$alignment
153151
);
152+
let results = (outs TTG_MemDescType:$result);
154153

155154
let builders = [
156155
OpBuilder<(ins "Type":$result),
@@ -172,13 +171,12 @@ def TTG_LocalAllocOp : TTG_Op<"local_alloc", [DeclareOpInterfaceMethods<MemoryEf
172171
($src^)? attr-dict `:` functional-type(operands, results)
173172
}];
174173

175-
let results = (outs TTG_MemDescType:$result);
176174
let hasFolder = 1;
177175
let hasVerifier = 1;
178176
}
179177

180178
// Deallocate shared memory
181-
def TTG_LocalDeallocOp : TTG_Op<"local_dealloc", [MemoryEffects<[MemFree<SharedMemory>]>]> {
179+
def TTG_LocalDeallocOp : TTG_Op<"local_dealloc"> {
182180
let summary = "dealloc buffer";
183181

184182
let description = [{
@@ -195,7 +193,7 @@ def TTG_LocalDeallocOp : TTG_Op<"local_dealloc", [MemoryEffects<[MemFree<SharedM
195193
operand.
196194
}];
197195

198-
let arguments = (ins TTG_MemDescType:$src);
196+
let arguments = (ins Arg<TTG_MemDescType, "", [MemFree<SharedMemory>]>:$src);
199197

200198
// Use qualified() otherwise "!ttg.memdesc<X>" is printed as "<X>".
201199
let assemblyFormat = [{$src attr-dict `:` qualified(type($src))}];
@@ -251,13 +249,17 @@ def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure,
251249
let hasFolder = 1;
252250
}
253251

254-
def TTG_LocalLoadOp : TTG_Op<"local_load", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
252+
def TTG_LocalLoadOp : TTG_Op<"local_load"> {
255253
let summary = "Load a buffer from local memory into a distributed tensor";
256254

257255
let description = [{
258256
Load a tensor from the local memory descriptor into a distributed tensor.
259257
}];
260-
let arguments = (ins TTG_MemDescType:$src, Optional<TTG_AsyncToken> :$token);
258+
let arguments = (ins
259+
Arg<TTG_MemDescType, "", [MemRead<SharedMemory>]>:$src,
260+
Optional<TTG_AsyncToken>:$token
261+
);
262+
let results = (outs TT_Tensor:$result);
261263

262264
let builders = [
263265
OpBuilder<(ins "Type":$retType, "Value":$src),
@@ -268,16 +270,18 @@ def TTG_LocalLoadOp : TTG_Op<"local_load", [DeclareOpInterfaceMethods<MemoryEffe
268270
// Use qualified() otherwise "!ttg.memdesc<X>" is printed as "<X>".
269271
let assemblyFormat = [{$src (`token` $token^)? attr-dict `:` qualified(type($src)) `->` type($result)}];
270272

271-
let results = (outs TT_Tensor:$result);
272273
}
273274

274-
def TTG_LocalStoreOp : TTG_Op<"local_store", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
275+
def TTG_LocalStoreOp : TTG_Op<"local_store"> {
275276
let summary = "Store a distributed tensor into a buffer in local memory";
276277

277278
let description = [{
278279
Store a distributed tensor into a buffer in local memory.
279280
}];
280-
let arguments = (ins TT_Tensor:$src, TTG_MemDescType:$dst);
281+
let arguments = (ins
282+
TT_Tensor:$src,
283+
Arg<TTG_MemDescType, "", [MemWrite<SharedMemory>]>:$dst
284+
);
281285

282286
let hasVerifier = 1;
283287
// Use qualified() otherwise "!ttg.memdesc<X>" is printed as "<X>".
@@ -313,7 +317,7 @@ def TTG_Fp4ToFpOp : TTG_Op<"fp4_to_fp", [Pure]> {
313317
}
314318

315319
// Allocate global memory
316-
def TTG_GlobalScratchAllocOp : TTG_Op<"global_scratch_alloc", [MemoryEffects<[MemAlloc<GlobalMemory>]>]> {
320+
def TTG_GlobalScratchAllocOp : TTG_Op<"global_scratch_alloc"> {
317321
let summary = "allocate a global memory buffer";
318322
let description = [{
319323
This operation allocates a buffer in global memory that is private to the current program.
@@ -323,7 +327,7 @@ def TTG_GlobalScratchAllocOp : TTG_Op<"global_scratch_alloc", [MemoryEffects<[Me
323327
I32Attr:$nbytes,
324328
I32Attr:$alignment
325329
);
326-
let results = (outs TT_Ptr:$result);
330+
let results = (outs Arg<TT_Ptr, "", [MemAlloc<GlobalMemory>]>:$result);
327331

328332
let builders = [
329333
OpBuilder<(ins "Type":$result, "int32_t":$nbytes, "int32_t":$alignment),

0 commit comments

Comments
 (0)