From 5954d73f153c2e55a8808ba4dcf4c510b909004c Mon Sep 17 00:00:00 2001 From: ravil-mobile Date: Tue, 28 Oct 2025 13:27:38 +0000 Subject: [PATCH] [ROCDL] Added tensor load/store ops This patch introduces tensor load/store ops in the ROCDL dialect Specifically: tensor loads/stores <=2D and <=5D variants Tests: Added lit-tests to check MLIR -> LLVM lowering --- mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 112 +++++++++++++++---- mlir/test/Dialect/LLVMIR/rocdl.mlir | 20 ++++ mlir/test/Target/LLVMIR/rocdl.mlir | 20 ++++ 3 files changed, 128 insertions(+), 24 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index d2df244eb9363..6bb968c24027f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -146,6 +146,35 @@ class ROCDL_DimGetterFunctionOp : + FixedVectorOfLengthAndType<[length], [elem]>, + BuildableType< + "::mlir::VectorType::get({" # length # "} ," + # elem.builderCall # ")">; + +def ROCDL_V2I16Type : ROCDL_ConcreteVector; +def ROCDL_V2F16Type : ROCDL_ConcreteVector; +def ROCDL_V2I32Type : ROCDL_ConcreteVector; +def ROCDL_V2BF16Type : ROCDL_ConcreteVector; +def ROCDL_V2F32Type : ROCDL_ConcreteVector; +def ROCDL_V3I32Type : ROCDL_ConcreteVector; +def ROCDL_V4I32Type : ROCDL_ConcreteVector; +def ROCDL_V6I32Type : ROCDL_ConcreteVector; +def ROCDL_V8I32Type : ROCDL_ConcreteVector; +def ROCDL_V8BF16Type : ROCDL_ConcreteVector; +def ROCDL_V8F16Type : ROCDL_ConcreteVector; +def ROCDL_V8F32Type : ROCDL_ConcreteVector; +def ROCDL_V16BF16Type : ROCDL_ConcreteVector; +def ROCDL_V16F16Type : ROCDL_ConcreteVector; +def ROCDL_V16F32Type : ROCDL_ConcreteVector; +def ROCDL_V32F16Type : ROCDL_ConcreteVector; +def ROCDL_V32BF16Type : ROCDL_ConcreteVector; +def ROCDL_V32F32Type : ROCDL_ConcreteVector; + //===----------------------------------------------------------------------===// // Wave-level primitives //===----------------------------------------------------------------------===// @@ -805,6 +834,65 @@ def ROCDL_RawBufferAtomicCmpSwap : }]; } +//===---------------------------------------------------------------------===// +// Raw tensor load/store intrinsics: gfx12+ + +def ROCDL_TensorLoadToLds : + ROCDL_IntrOp<"tensor.load.to.lds", [], [], [], 0, 0, 0, 0, [4], ["cpol"]>, + Arguments<(ins ROCDL_V4I32Type:$desc0, + ROCDL_V8I32Type:$desc1, + ROCDL_V4I32Type:$desc2, + ROCDL_V4I32Type:$desc3, + I32Attr:$cpol)>{ + let description = [{ + Loads tensor data from Global to LDS. Available on gfx12+. + }]; + let assemblyFormat = [{ + attr-dict operands `cachepolicy` $cpol + }]; +} + +def ROCDL_TensorLoadToLdsD2 : + ROCDL_IntrOp<"tensor.load.to.lds.d2", [], [], [], 0, 0, 0, 0, [2], ["cpol"]>, + Arguments<(ins ROCDL_V4I32Type:$desc0, + ROCDL_V8I32Type:$desc1, + I32Attr:$cpol)>{ + let description = [{ + Loads 2D tensor data from Global to LDS. Available on gfx12+. TODO + }]; + let assemblyFormat = [{ + attr-dict operands `cachepolicy` $cpol + }]; +} + +def ROCDL_TensorStoreFromLds : + ROCDL_IntrOp<"tensor.store.from.lds", [], [], [], 0, 0, 0, 0, [4], ["cpol"]>, + Arguments<(ins ROCDL_V4I32Type:$desc0, + ROCDL_V8I32Type:$desc1, + ROCDL_V4I32Type:$desc2, + ROCDL_V4I32Type:$desc3, + I32Attr:$cpol)>{ + let description = [{ + Stores tensor data from Global to LDS. Available on gfx12+. + }]; + let assemblyFormat = [{ + attr-dict operands `cachepolicy` $cpol + }]; +} + +def ROCDL_TensorStoreFromLdsD2 : + ROCDL_IntrOp<"tensor.store.from.lds.d2", [], [], [], 0, 0, 0, 0, [2], ["cpol"]>, + Arguments<(ins ROCDL_V4I32Type:$desc0, + ROCDL_V8I32Type:$desc1, + I32Attr:$cpol)>{ + let description = [{ + Stores tensor 2D data from Global to LDS. Available on gfx12+. TODO + }]; + let assemblyFormat = [{ + attr-dict operands `cachepolicy` $cpol + }]; +} + //===---------------------------------------------------------------------===// // MI-100 and MI-200 buffer atomic floating point add intrinsic @@ -932,30 +1020,6 @@ def ROCDL_Permlane32SwapOp : ROCDL_IntrOp<"permlane32.swap", [], [], }]; } -class ROCDL_ConcreteVector : - FixedVectorOfLengthAndType<[length], [elem]>, - BuildableType< - "::mlir::VectorType::get({" # length # "} ," - # elem.builderCall # ")">; - -def ROCDL_V2I16Type : ROCDL_ConcreteVector; -def ROCDL_V2F16Type : ROCDL_ConcreteVector; -def ROCDL_V2I32Type : ROCDL_ConcreteVector; -def ROCDL_V2BF16Type : ROCDL_ConcreteVector; -def ROCDL_V2F32Type : ROCDL_ConcreteVector; -def ROCDL_V3I32Type : ROCDL_ConcreteVector; -def ROCDL_V6I32Type : ROCDL_ConcreteVector; -def ROCDL_V8I32Type : ROCDL_ConcreteVector; -def ROCDL_V8BF16Type : ROCDL_ConcreteVector; -def ROCDL_V8F16Type : ROCDL_ConcreteVector; -def ROCDL_V8F32Type : ROCDL_ConcreteVector; -def ROCDL_V16BF16Type : ROCDL_ConcreteVector; -def ROCDL_V16F16Type : ROCDL_ConcreteVector; -def ROCDL_V16F32Type : ROCDL_ConcreteVector; -def ROCDL_V32F16Type : ROCDL_ConcreteVector; -def ROCDL_V32BF16Type : ROCDL_ConcreteVector; -def ROCDL_V32F32Type : ROCDL_ConcreteVector; - //===---------------------------------------------------------------------===// // 16-bit float intrinsics //===---------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir index d270ee8b089aa..0de5f38071791 100644 --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -776,6 +776,26 @@ llvm.func @rocdl.raw.buffer.i32(%rsrc : vector<4xi32>, llvm.return } +llvm.func @rocdl.tensor.load.store.ops( + %desc0 : vector<4xi32>, + %desc1 : vector<8xi32>, + %desc2 : vector<4xi32>, + %desc3 : vector<4xi32>) { + // CHECK-LABEL: @rocdl.tensor.load.store.ops( + // CHECK-SAME: %[[DESC0:.*]]: vector<4xi32>, %[[DESC1:.*]]: vector<8xi32>, %[[DESC2:.*]]: vector<4xi32>, %[[DESC3:.*]]: vector<4xi32>) + // CHECK: rocdl.tensor.load.to.lds %[[DESC0]], %[[DESC1]], %[[DESC2]], %[[DESC3]] cachepolicy 0 + // CHECK: rocdl.tensor.load.to.lds.d2 %[[DESC0]], %[[DESC1]] cachepolicy 0 + // CHECK: rocdl.tensor.store.from.lds %[[DESC0]], %[[DESC1]], %[[DESC2]], %[[DESC3]] cachepolicy 0 + // CHECK: rocdl.tensor.store.from.lds.d2 %[[DESC0]], %[[DESC1]] cachepolicy 0 + + rocdl.tensor.load.to.lds %desc0, %desc1, %desc2, %desc3 cachepolicy 0 + rocdl.tensor.load.to.lds.d2 %desc0, %desc1 cachepolicy 0 + + rocdl.tensor.store.from.lds %desc0, %desc1, %desc2, %desc3 cachepolicy 0 + rocdl.tensor.store.from.lds.d2 %desc0, %desc1 cachepolicy 0 + llvm.return +} + llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf16, %stoch: i32) -> i32 { // CHECK-LABEL: @rocdl_8bit_floats // CHECK: rocdl.cvt.f32.bf8 diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index 30126f6bff05a..eac58929795db 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -1250,6 +1250,26 @@ llvm.func @rocdl.raw.buffer.atomic.cmpswap(%rsrc : vector<4xi32>, llvm.return %val : i32 } +llvm.func @rocdl.tensor.load.store.ops( + %desc0 : vector<4xi32>, + %desc1 : vector<8xi32>, + %desc2 : vector<4xi32>, + %desc3 : vector<4xi32>) { + // CHECK-LABEL: @rocdl.tensor.load.store.ops( + // CHECK-SAME: <4 x i32> %[[DESC0:.*]], <8 x i32> %[[DESC1:.*]], <4 x i32> %[[DESC2:.*]], <4 x i32> %[[DESC3:.*]]) + // CHECK: call void @llvm.amdgcn.tensor.load.to.lds(<4 x i32> %[[DESC0]], <8 x i32> %[[DESC1]], <4 x i32> %[[DESC2]], <4 x i32> %[[DESC3]], i32 0) + // CHECK: call void @llvm.amdgcn.tensor.load.to.lds.d2(<4 x i32> %[[DESC0]], <8 x i32> %[[DESC1]], i32 0) + // CHECK: call void @llvm.amdgcn.tensor.store.from.lds(<4 x i32> %[[DESC0]], <8 x i32> %[[DESC1]], <4 x i32> %[[DESC2]], <4 x i32> %[[DESC3]], i32 0) + // CHECK: call void @llvm.amdgcn.tensor.store.from.lds.d2(<4 x i32> %[[DESC0]], <8 x i32> %[[DESC1]], i32 0) + // CHECK: ret void + rocdl.tensor.load.to.lds %desc0, %desc1, %desc2, %desc3 cachepolicy 0 + rocdl.tensor.load.to.lds.d2 %desc0, %desc1 cachepolicy 0 + + rocdl.tensor.store.from.lds %desc0, %desc1, %desc2, %desc3 cachepolicy 0 + rocdl.tensor.store.from.lds.d2 %desc0, %desc1 cachepolicy 0 + llvm.return +} + llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf16, %source_packed: vector<2xf16>, %stoch: i32) -> i32 { // CHECK-LABEL: @rocdl_8bit_floats // CHECK: call float @llvm.amdgcn.cvt.f32.bf8(i32 %{{.+}}, i32 0)