
Commit 6bb9119

olegshyshkov authored and Google-ML-Automation committed
[XLA:GPU] Add xla.get_dynamic_dim_size op and its lowering.
The new op is needed to implement the PadToStatic custom call.

PiperOrigin-RevId: 837107619
1 parent 431a46d commit 6bb9119

6 files changed: +141, -14 lines


xla/codegen/emitters/ir/tests/ops.mlir

Lines changed: 9 additions & 0 deletions
@@ -167,3 +167,12 @@ func.func @workgroup_id_op() -> (index, index, index) {
 // CHECK: [[WORKGROUP_ID_X:.*]] = xla.workgroup_id x {xla.range = [0 : index, 1023 : index]}
 // CHECK: [[WORKGROUP_ID_Y:.*]] = xla.workgroup_id y
 // CHECK: [[WORKGROUP_ID_Z:.*]] = xla.workgroup_id z
+
+// -----
+
+func.func @get_dynamic_dim_size(%in: tensor<16x8x4xf32>) -> (i32) {
+  %out = xla.get_dynamic_dim_size %in 1 : tensor<16x8x4xf32>
+  func.return %out : i32
+}
+// CHECK-LABEL: @get_dynamic_dim_size
+// CHECK: xla.get_dynamic_dim_size

xla/codegen/emitters/ir/xla_ops.td

Lines changed: 13 additions & 0 deletions
@@ -298,6 +298,19 @@ def WorkGroupIdOp : XLA_Op<"workgroup_id", [
   let results = (outs Index);
 }
 
+def GetDynamicDimSizeOp : XLA_Op<"get_dynamic_dim_size", [
+  Pure,
+]> {
+  let summary = "Returns the dynamic size of a dimension. The dynamic sizes are "
+                "stored in the same buffer, after the main values as an array "
+                "of s32. The `dim` argument can be larger than `tensor`'s rank, "
+                "because XLA has passes like flatten_tensors that only change "
+                "the view of the memory.";
+  let arguments = (ins AnyStaticShapeTensor:$tensor, I64Attr:$dim);
+  let results = (outs I32:$result);
+
+  let assemblyFormat = "$tensor $dim attr-dict `:` type($tensor)";
+}
 
 #endif // XLA_CODEGEN_EMITTERS_IR_XLA_OPS
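For orientation, here is a minimal sketch of the layout the summary describes, assuming a tensor<16x8x4xf32> whose dimension sizes are all dynamic; the byte offsets are illustrative and follow the lowering in lower_tensors.cc below:

// Assumed buffer layout for a tensor<16x8x4xf32> with dynamic dimension sizes:
//   bytes [0, 2048)    : 512 f32 elements (16 * 8 * 4)
//   bytes [2048, 2052) : s32 dynamic size of dim 0
//   bytes [2052, 2056) : s32 dynamic size of dim 1
//   bytes [2056, 2060) : s32 dynamic size of dim 2
%size = xla.get_dynamic_dim_size %in 1 : tensor<16x8x4xf32>  // reads the s32 at byte offset 2052

The same offset, 2052, is the constant checked in lower_tensors.mlir once the tensor has been flattened to tensor<512xf32>.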

xla/codegen/emitters/transforms/flatten_tensors.cc

Lines changed: 28 additions & 4 deletions
@@ -31,11 +31,11 @@ limitations under the License.
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeRange.h"
@@ -46,11 +46,12 @@ limitations under the License.
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/WalkResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "xla/backends/cpu/codegen/emitters/ir/xla_cpu_ops.h"
 #include "xla/backends/gpu/codegen/emitters/ir/xla_gpu_ops.h"
+#include "xla/codegen/emitters/ir/xla_ops.h"
 #include "xla/hlo/analysis/indexing_analysis.h"
-#include "xla/hlo/analysis/symbolic_expr.h"
 #include "xla/layout_util.h"
 #include "xla/shape_util.h"
 #include "xla/xla_data.pb.h"
@@ -748,6 +749,28 @@ struct RewriteSyncThreads : OpRewritePattern<gpu::SyncThreadsOp> {
   }
 };
 
+struct RewriteGetDynamicDimSizeOp : OpRewritePattern<GetDynamicDimSizeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GetDynamicDimSizeOp op,
+                                PatternRewriter& rewriter) const override {
+    auto tensor = op.getTensor();
+    auto tensor_type = tensor.getType();
+    if (tensor_type.getRank() < 2) {
+      return rewriter.notifyMatchFailure(op, "the tensor is already flat");
+    }
+
+    auto tensor_1D = rewriter
+                         .create<UnrealizedConversionCastOp>(
+                             op.getLoc(), GetFlattenedType(tensor_type), tensor)
+                         .getResult(0);
+    rewriter.replaceOpWithNewOp<GetDynamicDimSizeOp>(op, tensor_1D,
+                                                     op.getDim());
+
+    return mlir::success();
+  }
+};
+
 class FlattenTensorsPass
     : public impl::FlattenTensorsPassBase<FlattenTensorsPass> {
  public:
@@ -760,8 +783,10 @@ class FlattenTensorsPass
         RewriteAllocateShared,
         RewriteAtomicRMW,
         RewriteConstant,
+        RewriteCpuLoad,
         RewriteFor,
         RewriteFunctionSignatures,
+        RewriteGetDynamicDimSizeOp,
         RewriteIf,
         RewriteIndexSwitch,
         RewritePureCall,
@@ -771,8 +796,7 @@ class FlattenTensorsPass
         RewriteVectorExtract,
         RewriteVectorFromElements,
         RewriteVectorInsert,
-        RewriteVectorTransferRead,
-        RewriteCpuLoad
+        RewriteVectorTransferRead
        >(mlir_context);
     // clang-format on
     ApplyIndexingOp::getCanonicalizationPatterns(patterns, mlir_context);
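As a rough sketch (not verbatim pass output), RewriteGetDynamicDimSizeOp rewrites the rank-3 example from the tests by casting the operand to its flattened type and recreating the op with the same dim attribute; this is why the op definition allows `dim` to exceed the tensor's rank:

// Before:
%out = xla.get_dynamic_dim_size %in 1 : tensor<16x8x4xf32>
// After the pattern (sketch); the cast is expected to fold away once the
// function signature is flattened as well, as in flatten_tensors.mlir below:
%flat = builtin.unrealized_conversion_cast %in : tensor<16x8x4xf32> to tensor<512xf32>
%out = xla.get_dynamic_dim_size %flat 1 : tensor<512xf32>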

xla/codegen/emitters/transforms/lower_tensors.cc

Lines changed: 54 additions & 10 deletions
@@ -1325,6 +1325,55 @@ class RewriteAtomicRMW : public OpRewritePattern<AtomicRMWOp> {
   const DeviceSpec& device_spec_;
 };
 
+class RewriteGetDynamicDimSize : public OpRewritePattern<GetDynamicDimSizeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      GetDynamicDimSizeOp op, mlir::PatternRewriter& rewriter) const override {
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+
+    auto tensor = op.getTensor();
+    auto tensor_type = mlir::dyn_cast<mlir::RankedTensorType>(tensor.getType());
+
+    Type element_type = tensor_type.getElementType();
+    int64_t num_elements = tensor_type.getNumElements();
+    std::optional<int> sub_byte_width = GetSubByteBitWidth(element_type);
+    if (sub_byte_width) {
+      element_type = b.getI8Type();
+      // Elements are packed.
+      num_elements = CeilOfRatio<int64_t>(num_elements, 8 / *sub_byte_width);
+    }
+
+    // The offset of the dim size from the start of the buffer. The dynamic dim
+    // sizes are stored after the tensor data as a tail-allocated metadata of
+    // s32 type.
+    int64_t dynamic_size_offset_in_bytes =
+        num_elements * element_type.getIntOrFloatBitWidth() / 8 +
+        op.getDim() * b.getI32Type().getWidth() / 8;
+
+    int64_t alignment = dynamic_size_offset_in_bytes % 4;
+    // TODO(b/463569416): Support unaligned loads.
+    if (alignment != 0) {
+      return op->emitOpError("dynamic size offset is not 4-byte aligned");
+    }
+
+    auto ptr_type = ml::LLVMPointerType::get(b.getContext());
+    Value tensor_ptr =
+        b.create<UnrealizedConversionCastOp>(ptr_type, tensor).getResult(0);
+
+    Value addr_offset =
+        b.create<ml::ConstantOp>(b.getI64Type(), dynamic_size_offset_in_bytes);
+
+    Value addr_int = b.create<ml::PtrToIntOp>(b.getI64Type(), tensor_ptr);
+    Value metadata_addr_int = b.create<ml::AddOp>(addr_int, addr_offset);
+    Value metadata_addr = b.create<ml::IntToPtrOp>(ptr_type, metadata_addr_int);
+
+    rewriter.replaceOpWithNewOp<ml::LoadOp>(op, b.getI32Type(), metadata_addr);
+
+    return success();
+  }
+};
+
 class LowerTensorsPass : public impl::LowerTensorsPassBase<LowerTensorsPass> {
  public:
   explicit LowerTensorsPass(const LowerTensorsPassOptions& options)
@@ -1351,10 +1400,11 @@ class LowerTensorsPass : public impl::LowerTensorsPassBase<LowerTensorsPass> {
     mlir::RewritePatternSet tensor_patterns(mlir_context);
 
     tensor_patterns.add<RewriteAtomicRMW>(mlir_context, device_spec_);
-    tensor_patterns
-        .add<RewriteAllocateShared, RewriteNonScalarConstants,
-             RewriteSyncThreads, RewriteTensorExtract, RewriteTransferRead,
-             RewriteTensorInsert, RewriteTransferWrite>(mlir_context);
+    tensor_patterns.add<RewriteAllocateShared, RewriteGetDynamicDimSize,
+                        RewriteNonScalarConstants, RewriteSyncThreads,
+                        RewriteTensorExtract, RewriteTensorInsert,
+                        RewriteTransferRead, RewriteTransferWrite>(
+        mlir_context);
     if (mlir::failed(mlir::applyPatternsGreedily(getOperation(),
                                                  std::move(tensor_patterns)))) {
       signalPassFailure();
@@ -1396,14 +1446,8 @@ class LowerTensorsPass : public impl::LowerTensorsPassBase<LowerTensorsPass> {
         if (func.getArgAttr(base.getArgNumber(), "xla.invariant")) {
           load.setInvariant(true);
         }
-        return;
       }
     }
-    if (!device_spec_.IsCpu()) {
-      load.emitOpError(
-          "load op address is not (a GEP of) a function argument");
-      signalPassFailure();
-    }
    });
  }
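As a rough sketch (assumed shape of the output, not verbatim), RewriteGetDynamicDimSize lowers the tensor<512xf32> test case below to LLVM-dialect IR along these lines, with the offset constant computed as 512 elements * 4 bytes of data plus dim 1 * 4 bytes per s32 size = 2052:

// Approximate lowering of `%0 = xla.get_dynamic_dim_size %arg0 1 : tensor<512xf32>`:
%ptr = builtin.unrealized_conversion_cast %arg0 : tensor<512xf32> to !llvm.ptr
%off = llvm.mlir.constant(2052 : i64) : i64
%base = llvm.ptrtoint %ptr : !llvm.ptr to i64
%addr = llvm.add %base, %off : i64
%md_ptr = llvm.inttoptr %addr : i64 to !llvm.ptr
%size = llvm.load %md_ptr : !llvm.ptr -> i32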

xla/codegen/emitters/transforms/tests/flatten_tensors.mlir

Lines changed: 10 additions & 0 deletions
@@ -398,3 +398,13 @@ func.func @constant_vector() -> vector<2x3xf32> {
 // CHECK-LABEL: func.func @constant_vector
 // CHECK-SAME: -> vector<6xf32>
 // CHECK-NOT: builtin.unrealized_conversion_cast
+
+// -----
+
+func.func @get_dynamic_dim_size(%in: tensor<16x8x4xf32>) -> (i32) {
+  %out = xla.get_dynamic_dim_size %in 1 : tensor<16x8x4xf32>
+  func.return %out : i32
+}
+// CHECK-LABEL: func.func @get_dynamic_dim_size(
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<512xf32>) -> i32 {
+// CHECK: xla.get_dynamic_dim_size %[[TENSOR]] 1 : tensor<512xf32>

xla/codegen/emitters/transforms/tests/lower_tensors.mlir

Lines changed: 27 additions & 0 deletions
@@ -1106,3 +1106,30 @@ func.func @transfer_write_f4(%arg0: tensor<43xf4E2M1FN> {xla.slice_index = 1},
 // CHECK-LABEL: @transfer_write_f4
 // CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %arg0[0, 5] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<22 x i8>
 // CHECK: %[[OUT:.*]] = builtin.unrealized_conversion_cast %{{.*}} : vector<2xf4E2M1FN> to vector<2xi4>
+
+// -----
+
+func.func @get_dynamic_dim_size(%arg0: tensor<512xf32>) -> i32 {
+  %0 = xla.get_dynamic_dim_size %arg0 1 : tensor<512xf32>
+  func.return %0 : i32
+}
+// CHECK-LABEL: @get_dynamic_dim_size
+// CHECK: llvm.mlir.constant(2052 : i64) : i64
+
+// -----
+
+func.func @get_dynamic_dim_size_sub_byte_width(%arg0: tensor<512xi4>) -> i32 {
+  %0 = xla.get_dynamic_dim_size %arg0 1 : tensor<512xi4>
+  func.return %0 : i32
+}
+// CHECK-LABEL: @get_dynamic_dim_size_sub_byte_width
+// CHECK: llvm.mlir.constant(260 : i64) : i64
+
+// // -----
+
+func.func @get_dynamic_dim_size_unaligned(%arg0: tensor<7xf16>) -> i32 {
+  // expected-error @+1 {{'xla.get_dynamic_dim_size' op dynamic size offset is not 4-byte aligned}}
+  %0 = xla.get_dynamic_dim_size %arg0 1 : tensor<7xf16>
+  func.return %0 : i32
+}
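Following the offset computation in lower_tensors.cc, the constants in these tests work out as: 512 * 4 + 1 * 4 = 2052 for the f32 case; 512 / 2 + 1 * 4 = 260 for the i4 case, since two i4 elements are packed per byte; and 7 * 2 + 1 * 4 = 18 for the f16 case, which is not a multiple of 4 and therefore triggers the expected error.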
