@@ -41,8 +41,7 @@ include "TritonAMDGPUAttrDefs.td"
4141
4242
4343class TT_AMDGPU_Op<string mnemonic, list<Trait> traits = []> :
44- Op<TritonAMDGPU_Dialect, mnemonic, !listconcat(traits, [])> {
45- }
44+ Op<TritonAMDGPU_Dialect, mnemonic, !listconcat(traits, [])>;
4645
4746//
4847// Interfaces
@@ -53,8 +52,7 @@ def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">;
5352// ExtractSliceOp
5453//===----------------------------------------------------------------------===//
5554
56- def ExtractSliceOp
57- : TT_AMDGPU_Op<"extract_slice", [Pure]> {
55+ def ExtractSliceOp : TT_AMDGPU_Op<"extract_slice", [Pure]> {
5856 let summary = "extract slice operation";
5957 let description = [{
6058 The "extract_slice" operation enables extracting a slice of a tensor in
@@ -92,8 +90,10 @@ def ExtractSliceOp
9290 size of the slice is determined by the result type.
9391 }];
9492
95- let arguments = (ins AnyRankedTensor:$source,
96- DenseI64ArrayAttr:$static_offsets);
93+ let arguments = (ins
94+ AnyRankedTensor:$source,
95+ DenseI64ArrayAttr:$static_offsets
96+ );
9797 let results = (outs AnyRankedTensor:$result);
9898
9999 let builders = [
@@ -117,6 +117,10 @@ def ExtractSliceOp
117117 let hasVerifier = 1;
118118}
119119
120+ //===----------------------------------------------------------------------===//
121+ // InstructionSchedHint
122+ //===----------------------------------------------------------------------===//
123+
120124def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> {
121125 let summary = "A placeholder op for instruction scheduling hints within a basic block";
122126 let description = [{
@@ -156,8 +160,11 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> {
156160 let assemblyFormat = [{ attr-dict }];
157161}
158162
159- def CondBarrierOp : TT_AMDGPU_Op<"cond_barrier">,
160- Arguments<(ins I1:$pred)> {
163+ //===----------------------------------------------------------------------===//
164+ // CondBarrierOp
165+ //===----------------------------------------------------------------------===//
166+
167+ def CondBarrierOp : TT_AMDGPU_Op<"cond_barrier"> {
161168 let summary = "Conditionally set barriers to synchronize partial threads in a block";
162169
163170 let description = [{
@@ -170,22 +177,25 @@ def CondBarrierOp : TT_AMDGPU_Op<"cond_barrier">,
170177 NB. This doesn't set any memory fence.
171178 }];
172179
180+ let arguments = (ins I1:$pred);
181+
173182 let assemblyFormat = "$pred attr-dict";
174183}
175184
176- //
177- // AMD Buffer operations.
178- //
185+ //===----------------------------------------------------------------------===//
186+ // BufferLoadOp
187+ //===----------------------------------------------------------------------===//
188+
179189def BufferLoadOp : TT_AMDGPU_Op<"buffer_load", [
180190 SameLoadStoreOperandsAndResultEncoding,
181191 AttrSizedOperandSegments,
182192 MemoryEffects<[MemRead<GlobalMemory>]>,
183193 TypesMatchWith<"result element type matches the pointed type of ptr", "result", "ptr", "getPointerTypeToElement($_self)">,
184194 TypesMatchWith<"result and offsets have the same shape", "result", "offsets", "getI32SameShape($_self)">,
185195 TypesMatchWith<"result and mask have the same shape", "result", "mask", "getI1SameShape($_self)",
186- "($_op.getOperands().size () <= 3 ) || std::equal_to<>()">,
196+ "(cast<BufferLoadOp>( $_op).getMask () == nullptr ) || std::equal_to<>()">,
187197 TypesMatchWith<"result and other have the same type", "result", "other", "$_self",
188- "($_op.getOperands().size () <= 4 ) || std::equal_to<>()">,
198+ "(cast<BufferLoadOp>( $_op).getOther () == nullptr ) || std::equal_to<>()">,
189199]>{
190200 let summary = "Load from a scalar base pointer and a tensor offset";
191201 let description = [{
@@ -201,11 +211,10 @@ def BufferLoadOp : TT_AMDGPU_Op<"buffer_load", [
201211 when it converts to the buffer ops because it is important for optimizing
202212 the cache memory access.
203213 }];
204- let arguments = (
205- ins
214+ let arguments = (ins
206215 TT_Ptr:$ptr,
207216 I32Tensor:$offsets,
208- I32:$stride,
217+ Optional< I32> :$stride,
209218 DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
210219 Optional<TT_BoolTensor>:$mask,
211220 Optional<TT_Tensor>:$other
@@ -215,24 +224,29 @@ def BufferLoadOp : TT_AMDGPU_Op<"buffer_load", [
215224 let assemblyFormat = [{
216225 $ptr `[` $offsets `]` (`,` $mask^)? (`,` $other^)?
217226 oilist(`cacheModifier` `=` $cache)
218- `stride` `=` $stride
227+ ( `stride` `=` $stride^)?
219228 attr-dict `:` type($result)
220229 }];
221230}
222231
232+ //===----------------------------------------------------------------------===//
233+ // BufferAtomicRMWOp
234+ //===----------------------------------------------------------------------===//
235+
223236def BufferAtomicRMWOp : TT_AMDGPU_Op<"buffer_atomic_rmw", [
237+ AttrSizedOperandSegments,
224238 SameLoadStoreOperandsAndResultEncoding,
225239 MemoryEffects<[MemRead<GlobalMemory>]>,
226240 MemoryEffects<[MemWrite<GlobalMemory>]>,
227241 TypesMatchWith<"result element type matches the value type", "result", "value", "$_self">,
228242 TypesMatchWith<"result element type matches the pointed type of ptr", "result", "ptr", "getPointerTypeToElement($_self)">,
229243 TypesMatchWith<"result and offsets have the same shape", "result", "offsets", "getI32SameShape($_self)">,
230244 TypesMatchWith<"result and mask have the same shape", "result", "mask", "getI1SameShape($_self)",
231- "($_op.getOperands().size () <= 4 ) || std::equal_to<>()">,
245+ "(cast<BufferAtomicRMWOp>( $_op).getMask () == nullptr ) || std::equal_to<>()">,
232246 TypesMatchWith<"value element type matches the pointed type of ptr", "value", "ptr", "getPointerTypeToElement($_self)">,
233247 TypesMatchWith<"value and offsets have the same shape", "value", "offsets", "getI32SameShape($_self)">,
234248 TypesMatchWith<"value and mask have the same shape", "value", "mask", "getI1SameShape($_self)",
235- "($_op.getOperands().size () <= 4 ) || std::equal_to<>()">,
249+ "(cast<BufferAtomicRMWOp>( $_op).getMask () == nullptr ) || std::equal_to<>()">,
236250]>{
237251 let summary = "Atomic RMW op which reads, modifies, and writes to a scalar base pointer and a tensor offset";
238252 let description = [{
@@ -246,13 +260,12 @@ def BufferAtomicRMWOp : TT_AMDGPU_Op<"buffer_atomic_rmw", [
246260 the address difference between the first elements of each row in bytes. Compiler tries to obtain the `stride`
247261 when it converts to the buffer ops because it is important for optimizing the cache memory access.
248262 }];
249- let arguments = (
250- ins
263+ let arguments = (ins
251264 TT_AtomicRMWAttr:$atomic_rmw_op,
252265 TT_Ptr:$ptr,
253266 I32Tensor:$offsets,
254267 TT_Tensor:$value,
255- I32:$stride,
268+ Optional< I32> :$stride,
256269 TT_MemSemanticAttr:$sem,
257270 TT_MemSyncScopeAttr:$scope,
258271 Optional<TT_BoolTensor>:$mask
@@ -261,46 +274,23 @@ def BufferAtomicRMWOp : TT_AMDGPU_Op<"buffer_atomic_rmw", [
261274
262275 let assemblyFormat = [{
263276 $atomic_rmw_op `,` $sem `,` $scope `,` $value `,` $ptr `[` $offsets `]` (`,` $mask^)?
264- `stride` `=` $stride
277+ ( `stride` `=` $stride^)?
265278 attr-dict `:` type($result)
266279 }];
267280}
268281
269- def TTG_UpcastMXFPOp : TT_AMDGPU_Op<"upcast_mxfp", [Pure]> {
270- let summary = "Convert an mxfp tensor to bf16/fp16";
271-
272- let hasVerifier = 1;
273-
274- let description = [{
275- Compute the bf16 encoded in the given mxfp number as per
276- https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
277- }];
278- let arguments = (
279- ins
280- TT_Tensor:$src,
281- TT_Tensor:$scale,
282- TT_ScaleDotElemTypeAttr:$fp_type,
283- BoolAttr:$fastMath
284- );
285- let results = (outs TT_Tensor:$result);
286-
287- let assemblyFormat = [{
288- $src `,` $scale `fp_type` `=` $fp_type attr-dict `:` type($src) `,` type($scale) `->` type($result)
289- }];
290-
291- let extraClassDeclaration = [{
292- static RankedTensorType deduceOutputType(
293- TypedValue<RankedTensorType> inputTensor, ScaleDotElemType inputElemType, Type outputElemType);
294- }];
295- }
282+ //===----------------------------------------------------------------------===//
283+ // BufferStoreOp
284+ //===----------------------------------------------------------------------===//
296285
297286def BufferStoreOp : TT_AMDGPU_Op<"buffer_store", [
287+ AttrSizedOperandSegments,
298288 SameLoadStoreOperandsEncoding,
299289 MemoryEffects<[MemWrite<GlobalMemory>]>,
300290 TypesMatchWith<"value element type matches the pointed type of ptr", "value", "ptr", "getPointerTypeToElement($_self)">,
301291 TypesMatchWith<"value and offsets have the same shape", "value", "offsets", "getI32SameShape($_self)">,
302292 TypesMatchWith<"value and mask have the same shape", "value", "mask", "getI1SameShape($_self)",
303- "($_op.getOperands().size () <= 4 ) || std::equal_to<>()">,
293+ "(cast<BufferStoreOp>( $_op).getMask () == nullptr ) || std::equal_to<>()">,
304294]>{
305295 let summary = "Store into scalar base pointer and a tensor offset";
306296 let description = [{
@@ -316,22 +306,53 @@ def BufferStoreOp : TT_AMDGPU_Op<"buffer_store", [
316306 when it converts to the buffer ops because it is important for optimizing
317307 the cache memory access.
318308 }];
319- let arguments = (
320- ins
309+ let arguments = (ins
321310 TT_Tensor:$value,
322311 TT_Ptr:$ptr,
323312 I32Tensor:$offsets,
324- I32:$stride,
313+ Optional< I32> :$stride,
325314 DefaultValuedAttr<TT_CacheModifierAttr, "mlir::triton::CacheModifier::NONE">:$cache,
326315 Optional<TT_BoolTensor>:$mask
327316 );
328317
329318 let assemblyFormat = [{
330319 $value `,` $ptr `[` $offsets `]` (`,` $mask^)?
331320 oilist(`cacheModifier` `=` $cache)
332- `stride` `=` $stride
321+ ( `stride` `=` $stride^)?
333322 attr-dict `:` type($value)
334323 }];
335324}
336325
326+ //===----------------------------------------------------------------------===//
327+ // UpcastMXFPOp
328+ //===----------------------------------------------------------------------===//
329+
330+ def TTG_UpcastMXFPOp : TT_AMDGPU_Op<"upcast_mxfp", [Pure]> {
331+ let summary = "Convert an mxfp tensor to bf16/fp16";
332+
333+ let hasVerifier = 1;
334+
335+ let description = [{
336+ Compute the bf16 encoded in the given mxfp number as per
337+ https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
338+ }];
339+ let arguments = (
340+ ins
341+ TT_Tensor:$src,
342+ TT_Tensor:$scale,
343+ TT_ScaleDotElemTypeAttr:$fp_type,
344+ BoolAttr:$fastMath
345+ );
346+ let results = (outs TT_Tensor:$result);
347+
348+ let assemblyFormat = [{
349+ $src `,` $scale `fp_type` `=` $fp_type attr-dict `:` type($src) `,` type($scale) `->` type($result)
350+ }];
351+
352+ let extraClassDeclaration = [{
353+ static RankedTensorType deduceOutputType(
354+ TypedValue<RankedTensorType> inputTensor, ScaleDotElemType inputElemType, Type outputElemType);
355+ }];
356+ }
357+
337358#endif
0 commit comments