Skip to content

Commit 8da9324

Browse files
apaszke authored and Google-ML-Automation committed
[Mosaic GPU] Fuse slicing into s4 -> bf16 upcasts
This allows us to significantly simplify the generated PTX/SASS, which is currently cluttered with LLVM trying to align slices to start at bit 0 and failing to CSE the right shifts.

PiperOrigin-RevId: 737967890
1 parent 7a459f0 commit 8da9324

File tree

4 files changed

+78
-18
lines changed

4 files changed

+78
-18
lines changed

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,15 +1373,30 @@ def upcast_to_bf16(reg: ir.Value, reg_shr: ir.Value, part: int):
13731373
for group_size in (8, 4, 2):
13741374
int_ty = ir.IntegerType.get_signless(group_size * 4)
13751375
while vector_len - offset >= group_size:
1376-
reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
1377-
reg_slice_int = utils.bitcast(reg_slice, int_ty)
1378-
if int_ty != i32:
1379-
reg_slice_int = arith.extsi(i32, reg_slice_int)
1380-
reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
1381-
out_int_regs.extend(
1382-
upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
1383-
for part in range(group_size // 2)
1384-
)
1376+
# If the vector originates from a slice (common after relayouts), we
1377+
# can fuse the slicing into the conversion and prevent LLVM from
1378+
# generating a bunch of shifts to align the vector data to the LSB.
1379+
# This also lets us share the right shift among more vectors.
1380+
if (isinstance(slice_op := reg.owner.opview, vector.ExtractStridedSliceOp)
1381+
and utils.bitwidth(slice_op.vector.type) == 32
1382+
and slice_op.strides[0].value == 1):
1383+
slice_offset = slice_op.offsets[0].value + offset
1384+
reg_int = utils.bitcast(slice_op.vector, i32)
1385+
reg_int_shr = arith.shrui(reg_int, c(4, i32))
1386+
out_int_regs.extend(
1387+
upcast_to_bf16(reg_int, reg_int_shr, part=(slice_offset // 2 + part))
1388+
for part in range(group_size // 2)
1389+
)
1390+
else:
1391+
reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
1392+
reg_slice_int = utils.bitcast(reg_slice, int_ty)
1393+
if int_ty != i32:
1394+
reg_slice_int = arith.extsi(i32, reg_slice_int)
1395+
reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
1396+
out_int_regs.extend(
1397+
upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
1398+
for part in range(group_size // 2)
1399+
)
13851400
offset += group_size
13861401
assert offset == vector_len
13871402
out_vec_int = utils.vector_concat([

jax/experimental/mosaic/gpu/utils.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def bitwidth_impl(ty: ir.Type):
346346
return ir.IntegerType(ty).width
347347
if ir.FloatType.isinstance(ty):
348348
return ir.FloatType(ty).width
349-
if dialect is not None and ir.Type.parse("!mosaic_gpu.barrier"):
349+
if dialect is not None and ty == ir.Type.parse("!mosaic_gpu.barrier"):
350350
return MBARRIER_BYTES * 8
351351
if ir.VectorType.isinstance(ty):
352352
vty = ir.VectorType(ty)
@@ -1237,17 +1237,15 @@ def ceil_div(x: int, y: int):
12371237

12381238

12391239
def vector_slice(v: ir.Value, s: slice):
  """Extracts the contiguous slice ``s`` from the 1-D vector ``v``.

  The slice is emitted as a single ``vector.extract_strided_slice`` op so
  that downstream code can pattern-match it (e.g. to fuse the slicing into
  a type conversion instead of materializing the sliced vector).

  Args:
    v: A value of 1-D ``ir.VectorType``.
    s: A Python slice selecting the elements to keep. Only unit-step
      slices are supported, because the op is emitted with stride 1.

  Returns:
    A vector value containing the selected elements.

  Raises:
    NotImplementedError: If ``v`` is not 1-D or ``s`` has a step != 1.
  """
  v_ty = ir.VectorType(v.type)
  if len(v_ty.shape) != 1:
    raise NotImplementedError(v_ty)
  [v_len] = v_ty.shape
  # Normalize the slice against the vector length so that negative or
  # missing bounds resolve to concrete in-range indices before we emit the
  # op (e.g. slice(-2, None) must become offset v_len - 2, not -2).
  idx = range(v_len)[s]
  if idx.step != 1:
    raise NotImplementedError("only unit-step slices are supported")
  slice_length = len(idx)
  return vector.extract_strided_slice(
      ir.VectorType.get((slice_length,), v_ty.element_type),
      v, [idx.start], [slice_length], [1],
  )
12511249

12521250

12531251
def vector_concat(vectors: Sequence[ir.Value]) -> ir.Value:

jaxlib/mosaic/gpu/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ cc_library(
6565
"@llvm-project//mlir:Pass",
6666
"@llvm-project//mlir:Support",
6767
"@llvm-project//mlir:TransformUtils",
68+
"@llvm-project//mlir:VectorDialect",
6869
],
6970
)
7071

jaxlib/mosaic/gpu/passes.cc

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License.
1414
==============================================================================*/
1515

1616
#include "jaxlib/mosaic/gpu/passes.h"
17+
#include <cstdint>
1718
#include <memory>
1819
#include <utility>
1920
#include <vector>
@@ -23,6 +24,7 @@ limitations under the License.
2324
#include "mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h"
2425
#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h"
2526
#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h"
27+
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
2628
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
2729
#include "mlir/include/mlir/IR/BuiltinOps.h"
2830
#include "mlir/include/mlir/IR/SymbolTable.h"
@@ -36,6 +38,49 @@ namespace gpu {
3638

3739
namespace {
3840

41+
// Upstream MLIR does not implement an LLVM lowering pattern for this op.
42+
struct ConvertExtractStridedSlicePattern final
43+
: public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
44+
using OpConversionPattern::OpConversionPattern;
45+
mlir::LogicalResult matchAndRewrite(
46+
mlir::vector::ExtractStridedSliceOp op, OpAdaptor subst,
47+
mlir::ConversionPatternRewriter &rewriter) const override {
48+
auto vty = op.getSourceVectorType();
49+
if (vty.getRank() != 1) {
50+
return rewriter.notifyMatchFailure(op, "only 1-D vectors are supported");
51+
}
52+
int64_t size =
53+
(*op.getSizes().getAsRange<mlir::IntegerAttr>().begin()).getSInt();
54+
if (size < 0) {
55+
return rewriter.notifyMatchFailure(op, "size is negative");
56+
}
57+
int64_t start =
58+
(*op.getOffsets().getAsRange<mlir::IntegerAttr>().begin()).getSInt();
59+
int64_t stride =
60+
(*op.getStrides().getAsRange<mlir::IntegerAttr>().begin()).getSInt();
61+
if (stride != 1) {
62+
return rewriter.notifyMatchFailure(op, "only stride 1 is supported");
63+
}
64+
if (start < 0 || start + size > vty.getShape()[0]) {
65+
return rewriter.notifyMatchFailure(op, "slice is out of bounds");
66+
}
67+
mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(
68+
op.getLoc(), op.getResult().getType());
69+
for (int64_t i = 0; i < size; ++i) {
70+
result = rewriter.create<mlir::LLVM::InsertElementOp>(
71+
op.getLoc(), result,
72+
rewriter.create<mlir::LLVM::ExtractElementOp>(
73+
op.getLoc(), subst.getVector(),
74+
rewriter.create<mlir::LLVM::ConstantOp>(
75+
op.getLoc(), rewriter.getI32IntegerAttr(i + start))),
76+
rewriter.create<mlir::LLVM::ConstantOp>(
77+
op.getLoc(), rewriter.getI32IntegerAttr(i)));
78+
}
79+
rewriter.replaceOp(op, result);
80+
return mlir::success();
81+
}
82+
};
83+
3984
class ConvertGpuToLLVMPass
4085
: public jaxlib::mlir::Pass<ConvertGpuToLLVMPass, mlir::ModuleOp> {
4186
public:
@@ -58,6 +103,7 @@ class ConvertGpuToLLVMPass
58103
});
59104
auto symtab = mlir::SymbolTable(getOperation());
60105
mlir::populateGpuToLLVMConversionPatterns(converter, patterns, false);
106+
patterns.insert<ConvertExtractStridedSlicePattern>(&getContext());
61107
if (mlir::applyPartialConversion(getOperation(), target,
62108
std::move(patterns))
63109
.failed()) {

0 commit comments

Comments (0)