Commit f147dd2

Author: jax authors (committed)
Merge pull request #21800 from superbobry:typing
PiperOrigin-RevId: 642224964
2 parents (f847350 + e8f20ad) · commit f147dd2

File tree: 2 files changed (+39 −83 lines)

jax/_src/pallas/triton/lowering.py

Lines changed: 0 additions & 2 deletions
@@ -253,9 +253,7 @@ def lower_jaxpr_to_triton_module(
     in_shapes,
     grid_mapping: GridMapping,
     name: str,
-    cuda_options: Any,
 ) -> LoweringResult:
-  # TODO(slebedev): Use cuda_options= during lowering.
   jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), instantiate=True)
   with _new_ir_context(), ir.Location.unknown():
     module = ir.Module.create()

jax/_src/pallas/triton/pallas_call_registration.py

Lines changed: 39 additions & 81 deletions
@@ -14,9 +14,6 @@
 
 """Module registering a lowering rule for pallas_call on GPU."""
 
-# TODO(sharadmv): Enable type checking.
-# mypy: ignore-errors
-
 from __future__ import annotations
 
 import io
@@ -36,77 +33,13 @@ def normalize_grid(grid: pallas_core.StaticGrid) -> tuple[int, int, int]:
     grid = (grid,)
   elif len(grid) > 3:
     raise ValueError("`grid` should have three or fewer dimensions.")
-  return tuple(grid) + (1,) * (3 - len(grid))
+  return tuple(grid) + (1,) * (3 - len(grid))  # type: ignore
 
 
 def avals_to_layouts(avals):
   return [list(reversed(range(aval.ndim))) for aval in avals]
 
 
-def _pallas_call_ttir_lowering(
-    ctx: mlir.LoweringRuleContext,
-    *in_nodes,
-    jaxpr: jax_core.Jaxpr,
-    name: str,
-    in_shapes: tuple[jax.ShapeDtypeStruct, ...],
-    out_shapes: tuple[jax.ShapeDtypeStruct, ...],
-    debug: bool,
-    input_output_aliases: tuple[tuple[int, int], ...],
-    grid_mapping: pallas_core.GridMapping,
-    triton_params: dict[str, Any] | None = None,
-    num_warps: int,
-    num_stages: int,
-):
-  # TODO(sharadmv): Handle multiple devices with different capabilities.
-  d, *_ = jax.local_devices(backend="gpu")
-  cuda_options = dict(
-      compute_capability=d.compute_capability,
-      num_warps=num_warps,
-      num_stages=num_stages,
-      debug=debug,
-  )
-
-  lowering_result = lowering.lower_jaxpr_to_triton_module(
-      jaxpr, (*in_shapes, *out_shapes), grid_mapping, name, cuda_options
-  )
-  module_op = lowering_result.module.operation
-  if debug:
-    print(module_op.get_asm(enable_debug_info=True, pretty_debug_info=True))
-
-  grid_x, grid_y, grid_z = normalize_grid(lowering_result.grid)
-  out_types = [
-      ir.RankedTensorType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype))
-      for shape in out_shapes
-  ]
-  buf = io.BytesIO()
-  module_op.write_bytecode(buf)
-  backend_config = dict(
-      name=ir.StringAttr.get(name),
-      ir=ir.StringAttr.get(buf.getvalue()),
-      num_stages=mlir.i32_attr(num_stages),
-      num_warps=mlir.i32_attr(num_warps),
-      grid_x=mlir.i32_attr(grid_x),
-      grid_y=mlir.i32_attr(grid_y),
-      grid_z=mlir.i32_attr(grid_z),
-      debug=ir.BoolAttr.get(debug),
-  )
-  if "serialized_metadata" in (triton_params or {}):
-    # This field is unstable and may be removed in the future.
-    backend_config["serialized_metadata"] = ir.StringAttr.get(
-        triton_params["serialized_metadata"]
-    )
-  return mlir.custom_call(
-      call_target_name="__gpu$xla.gpu.triton",
-      result_types=out_types,
-      operands=in_nodes,
-      backend_config=backend_config,
-      api_version=4,
-      operand_layouts=avals_to_layouts(ctx.avals_in),
-      result_layouts=avals_to_layouts(ctx.avals_out),
-      operand_output_aliases=dict(input_output_aliases),
-  ).results
-
-
 def pallas_call_lowering(
     ctx: mlir.LoweringRuleContext,
     *in_nodes,
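
For context (not part of the commit): the type-ignored return line in normalize_grid above simply pads a grid spec out to three dimensions. A standalone sketch of that expression, using a hypothetical helper name so it does not depend on the private module:

# Hypothetical re-statement of the padding done by normalize_grid's return line.
def _normalize_grid_sketch(grid):
  if isinstance(grid, int):
    grid = (grid,)
  return tuple(grid) + (1,) * (3 - len(grid))

assert _normalize_grid_sketch(8) == (8, 1, 1)
assert _normalize_grid_sketch((8, 4)) == (8, 4, 1)
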
@@ -154,17 +87,42 @@ def pallas_call_lowering(
     print(jaxpr)
     print(grid_mapping)
 
-  return _pallas_call_ttir_lowering(
-      ctx,
-      *in_nodes,
-      jaxpr=jaxpr,
-      name=name,
-      in_shapes=in_shapes,
-      out_shapes=out_shapes,
-      debug=debug,
-      input_output_aliases=input_output_aliases,
-      grid_mapping=grid_mapping,
-      triton_params=triton_params,
-      num_warps=num_warps,
-      num_stages=num_stages,
+  lowering_result = lowering.lower_jaxpr_to_triton_module(
+      jaxpr, (*in_shapes, *out_shapes), grid_mapping, name,
+  )
+  module_op = lowering_result.module.operation
+  if debug:
+    print(module_op.get_asm(enable_debug_info=True, pretty_debug_info=True))
+
+  grid_x, grid_y, grid_z = normalize_grid(lowering_result.grid)
+  out_types = [
+      ir.RankedTensorType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype))
+      for shape in out_shapes
+  ]
+  buf = io.BytesIO()
+  module_op.write_bytecode(buf)
+  backend_config = dict(
+      name=ir.StringAttr.get(name),
+      ir=ir.StringAttr.get(buf.getvalue()),  # type: ignore
+      num_stages=mlir.i32_attr(num_stages),
+      num_warps=mlir.i32_attr(num_warps),
+      grid_x=mlir.i32_attr(grid_x),
+      grid_y=mlir.i32_attr(grid_y),
+      grid_z=mlir.i32_attr(grid_z),
+      debug=ir.BoolAttr.get(debug),
+  )
+  if "serialized_metadata" in (triton_params or {}):
+    # This field is unstable and may be removed in the future.
+    backend_config["serialized_metadata"] = ir.StringAttr.get(
+        triton_params["serialized_metadata"]
     )
+  return mlir.custom_call(
+      call_target_name="__gpu$xla.gpu.triton",
+      result_types=out_types,
+      operands=in_nodes,
+      backend_config=backend_config,
+      api_version=4,
+      operand_layouts=avals_to_layouts(ctx.avals_in),
+      result_layouts=avals_to_layouts(ctx.avals_out),
+      operand_output_aliases=dict(input_output_aliases),
+  ).results
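
For orientation (not part of this commit): the lowering above is what runs when a Pallas kernel is compiled for GPU. A minimal sketch of a kernel that exercises this path, assuming a CUDA-capable device and the public jax.experimental.pallas API:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
  # Read the input block, add one, and write the output block.
  o_ref[...] = x_ref[...] + 1

x = jnp.arange(8, dtype=jnp.float32)
y = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
)(x)

On GPU, pallas_call_lowering serializes the lowered Triton module with write_bytecode and hands it to XLA through the __gpu$xla.gpu.triton custom call shown in the diff.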
