Skip to content

Commit 22fbb0c

Browse files
Updates LLVM usage to match [355e0f94af5a](https://github.com/llvm/llvm-project/commit/355e0f94af5a)

PiperOrigin-RevId: 834865231
1 parent a83c167 commit 22fbb0c

File tree

6 files changed

+55
-20
lines changed

6 files changed

+55
-20
lines changed

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3112,8 +3112,7 @@ def _yielded_values(outs, avals):
31123112
switch_op = scf_dialect.IndexSwitchOp(
31133113
yielded_types,
31143114
_as_index(_ensure_ir_value(index, index_aval.dtype)),
3115-
ir.DenseI64ArrayAttr.get(range(len(branches) - 1)),
3116-
num_caseRegions=len(branches) - 1,
3115+
range(len(branches) - 1),
31173116
)
31183117

31193118
# ``RegionSequence`` in MLIR does not support slicing, so the
@@ -3124,7 +3123,7 @@ def _yielded_values(outs, avals):
31243123
regions = regions[1:] + regions[:1]
31253124
treedef = None
31263125
for branch, region in zip(branches, regions):
3127-
with ir.InsertionPoint(region.blocks.append()):
3126+
with ir.InsertionPoint(region.blocks[0]):
31283127
outs = lower_jaxpr_to_mosaic_gpu(
31293128
ctx.module_ctx, ctx.launch_ctx, branch.jaxpr, args, consts=branch.consts
31303129
)

jax/experimental/mosaic/gpu/dialect_lowering.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,7 +1332,7 @@ def _mgpu_arrive_expect_tx_op_lowering_rule(
13321332
barrier = utils.DialectBarrierRef.from_barrier_memref(
13331333
arrive_expect_tx_op.barrier
13341334
)
1335-
nvvm.mbarrier_arrive_expect_tx_shared(barrier.get_ptr(), bytes)
1335+
nvvm.mbarrier_arrive_expect_tx(barrier.get_ptr(), bytes)
13361336

13371337
return []
13381338

@@ -2155,15 +2155,14 @@ def _index_switch_op_lowering_rule(
21552155
_infer_flat_result_types(switch_op, out_layouts),
21562156
switch_op.arg,
21572157
switch_op.cases,
2158-
len(switch_op.regions) - 1,
21592158
)
21602159

21612160
results_template: Sequence[_VectorTemplate | None] = []
21622161
for region, new_region in zip(
21632162
switch_op.regions, new_switch_op.regions, strict=True
21642163
):
21652164
[block] = region.blocks
2166-
new_block = new_region.blocks.append()
2165+
new_block = new_region.blocks[0]
21672166
results_template = _move_scf_block_to_block_with_flattened_arguments(
21682167
ctx, block, new_block, scf.YieldOp, []
21692168
)

jax/experimental/mosaic/gpu/launch_context.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1157,8 +1157,10 @@ def async_copy(
11571157

11581158
if arrive:
11591159
arrive_predicate = utils.single_thread_predicate(utils.ThreadSubset.WARPGROUP)
1160-
nvvm.mbarrier_arrive_expect_tx_shared(
1161-
barrier_ptr, transfer_bytes, predicate=arrive_predicate,
1160+
nvvm.mbarrier_arrive_expect_tx(
1161+
barrier_ptr,
1162+
transfer_bytes,
1163+
predicate=arrive_predicate,
11621164
)
11631165

11641166
gmem_strides, _ = gmem_ref_ty.get_strides_and_offset()
@@ -1286,7 +1288,7 @@ def async_copy(
12861288
arith.CmpIPredicate.eq, self.cluster_idx(collective), c(0, index),
12871289
)
12881290
arrive_predicate = arith.andi(predicate, first_block)
1289-
nvvm.mbarrier_arrive_expect_tx_shared(
1291+
nvvm.mbarrier_arrive_expect_tx(
12901292
barrier_ptr, transfer_bytes, predicate=arrive_predicate
12911293
)
12921294
rank = len(slice_shape)
@@ -1307,7 +1309,7 @@ def async_copy(
13071309
)
13081310
else:
13091311
if arrive:
1310-
nvvm.mbarrier_arrive_expect_tx_shared(
1312+
nvvm.mbarrier_arrive_expect_tx(
13111313
barrier_ptr, transfer_bytes, predicate=predicate
13121314
)
13131315
if collective_size > 1:

jax/experimental/mosaic/gpu/utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -995,7 +995,7 @@ def wait_parity(self, parity, orders_tensor_core=False):
995995
i32 = ir.IntegerType.get_signless(32)
996996
ticks = arith.constant(i32, 10000000)
997997
parity = arith.extui(i32, parity)
998-
nvvm.mbarrier_try_wait_parity_shared(self.get_ptr(), parity, ticks)
998+
nvvm.mbarrier_try_wait_parity(self.get_ptr(), parity, ticks)
999999
if orders_tensor_core:
10001000
llvm.inline_asm(
10011001
ir.Type.parse("!llvm.void"),
@@ -1064,9 +1064,7 @@ def arrive_expect_tx(
10641064
elif ir.IndexType.isinstance(bytes.type):
10651065
i32 = ir.IntegerType.get_signless(32)
10661066
bytes = arith.index_cast(i32, bytes)
1067-
nvvm.mbarrier_arrive_expect_tx_shared(
1068-
self.get_ptr(), bytes, predicate=predicate
1069-
)
1067+
nvvm.mbarrier_arrive_expect_tx(self.get_ptr(), bytes, predicate=predicate)
10701068

10711069
def get_ptr(self):
10721070
i64 = ir.IntegerType.get_signless(64)

jaxlib/mosaic/gpu/serde.cc

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ constexpr llvm::StringRef kVersionAttrName = "stable_mosaic_gpu.version";
5050
// lowering after 2025-11-13.
5151
// TODO(apaszke): Update the forward-compatible version to 5 in Mosaic GPU
5252
// lowering after 2025-12-07.
53-
constexpr int kVersion = 5;
53+
// TODO(apaszke): Update the forward-compatible version to 6 in Mosaic GPU
54+
// lowering after 2025-12-18.
55+
constexpr int kVersion = 6;
5456

5557
using SerdeRuleType = jaxlib::mosaic::SerdeRuleType;
5658

@@ -174,6 +176,38 @@ LogicalResult nvvm_mbarrier_init_shared_upgrade(Operation* op, int version,
174176
return success();
175177
}
176178

179+
LogicalResult nvvm_mbarrier_try_wait_parity_shared_upgrade(Operation* op,
180+
int version,
181+
bool& erased) {
182+
// https://github.com/llvm/llvm-project/commit/7eeae8e41d7827d84de12df7b5ecfab3058900cb
183+
if (version < 6) {
184+
mlir::OpBuilder b(op->getParentRegion());
185+
b.setInsertionPointAfter(op);
186+
mlir::NVVM::MBarrierTryWaitParityOp::create(
187+
b, op->getLoc(), op->getOperand(0), op->getOperand(1),
188+
op->getOperand(2));
189+
op->erase();
190+
erased = true;
191+
}
192+
return success();
193+
}
194+
195+
LogicalResult nvvm_mbarrier_arrive_expect_tx_shared_upgrade(Operation* op,
196+
int version,
197+
bool& erased) {
198+
// https://github.com/llvm/llvm-project/commit/7eeae8e41d7827d84de12df7b5ecfab3058900cb
199+
if (version < 6) {
200+
mlir::OpBuilder b(op->getParentRegion());
201+
b.setInsertionPointAfter(op);
202+
mlir::NVVM::MBarrierArriveExpectTxOp::create(
203+
b, op->getLoc(), op->getOperand(0), op->getOperand(1),
204+
op->getNumOperands() < 3 ? Value{} : op->getOperand(2));
205+
op->erase();
206+
erased = true;
207+
}
208+
return success();
209+
}
210+
177211
const llvm::StringMap<SerdeRuleType>& upgrade_rules() {
178212
static auto rules = new llvm::StringMap<SerdeRuleType>{
179213
{::llvm::StringLiteral("vector.extractelement"),
@@ -185,6 +219,10 @@ const llvm::StringMap<SerdeRuleType>& upgrade_rules() {
185219
{::llvm::StringLiteral("vector.splat"), vector_splat_upgrade},
186220
{::llvm::StringLiteral("nvvm.mbarrier.init.shared"),
187221
nvvm_mbarrier_init_shared_upgrade},
222+
{::llvm::StringLiteral("nvvm.mbarrier.try_wait.parity.shared"),
223+
nvvm_mbarrier_try_wait_parity_shared_upgrade},
224+
{::llvm::StringLiteral("nvvm.mbarrier.arrive.expect_tx.shared"),
225+
nvvm_mbarrier_arrive_expect_tx_shared_upgrade},
188226
};
189227
return *rules;
190228
}

tests/mosaic/gpu_layout_inference_test.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -509,20 +509,19 @@ def test_infer_index_switch_op_layouts(
509509
index_switch = scf.IndexSwitchOp(
510510
[out_type, out_type, f32],
511511
condition,
512-
ir.DenseI64ArrayAttr.get(range(3)),
513-
num_caseRegions=2,
512+
range(2),
514513
)
515-
with ir.InsertionPoint(index_switch.caseRegions[0].blocks.append()):
514+
with ir.InsertionPoint(index_switch.caseRegions[0].blocks[0]):
516515
out0, out1, dummy0 = undefs(out_type, out_type, f32)
517516
if out0_layout is not None:
518517
out0 = layout_cast(out0, out0_layout)
519518
yield0 = scf.YieldOp([out0, out1, dummy0])
520-
with ir.InsertionPoint(index_switch.caseRegions[1].blocks.append()):
519+
with ir.InsertionPoint(index_switch.caseRegions[1].blocks[0]):
521520
out2, out3, dummy1 = undefs(out_type, out_type, f32)
522521
if out3_layout is not None:
523522
out3 = layout_cast(out3, out3_layout)
524523
yield1 = scf.YieldOp([out2, out3, dummy1])
525-
with ir.InsertionPoint(index_switch.defaultRegion.blocks.append()):
524+
with ir.InsertionPoint(index_switch.defaultRegion.blocks[0]):
526525
out4, out5, dummy2 = undefs(out_type, out_type, f32)
527526
if out4_layout is not None:
528527
out4 = layout_cast(out4, out4_layout)

0 commit comments

Comments
 (0)