14
14
15
15
"""Module registering a lowering rule for pallas_call on GPU."""
16
16
17
- # TODO(sharadmv): Enable type checking.
18
- # mypy: ignore-errors
19
-
20
17
from __future__ import annotations
21
18
22
19
import io
@@ -36,77 +33,13 @@ def normalize_grid(grid: pallas_core.StaticGrid) -> tuple[int, int, int]:
36
33
grid = (grid ,)
37
34
elif len (grid ) > 3 :
38
35
raise ValueError ("`grid` should have three or fewer dimensions." )
39
- return tuple (grid ) + (1 ,) * (3 - len (grid ))
36
+ return tuple (grid ) + (1 ,) * (3 - len (grid )) # type: ignore
40
37
41
38
42
39
def avals_to_layouts(avals):
  """Returns a descending-dimension (row-major) layout for each aval.

  For an aval of rank r the layout is [r-1, r-2, ..., 0]; rank-0 avals get
  an empty layout.
  """
  layouts = []
  for aval in avals:
    layouts.append(list(range(aval.ndim - 1, -1, -1)))
  return layouts
44
41
45
42
46
def _pallas_call_ttir_lowering(
    ctx: mlir.LoweringRuleContext,
    *in_nodes,
    jaxpr: jax_core.Jaxpr,
    name: str,
    in_shapes: tuple[jax.ShapeDtypeStruct, ...],
    out_shapes: tuple[jax.ShapeDtypeStruct, ...],
    debug: bool,
    input_output_aliases: tuple[tuple[int, int], ...],
    grid_mapping: pallas_core.GridMapping,
    triton_params: dict[str, Any] | None = None,
    num_warps: int,
    num_stages: int,
):
  """Lowers a pallas_call jaxpr to a Triton-backed XLA custom call.

  The kernel jaxpr is lowered to a Triton MLIR module, serialized to MLIR
  bytecode, and embedded in the backend config of a
  ``__gpu$xla.gpu.triton`` custom call.

  Args:
    ctx: MLIR lowering rule context supplying input/output avals.
    *in_nodes: MLIR values for the kernel operands.
    jaxpr: the kernel body to lower.
    name: kernel name, also recorded in the backend config.
    in_shapes: shape/dtype structs for the inputs.
    out_shapes: shape/dtype structs for the outputs.
    debug: if True, prints the lowered MLIR module.
    input_output_aliases: (operand index, output index) aliasing pairs.
    grid_mapping: pallas grid/block mapping for the kernel.
    triton_params: optional extra Triton parameters; only the (unstable)
      "serialized_metadata" entry is consumed here.
    num_warps: number of warps to compile the kernel with.
    num_stages: number of pipeline stages to compile the kernel with.

  Returns:
    The result values of the emitted custom call.
  """
  # TODO(sharadmv): Handle multiple devices with different capabilities.
  # NOTE(review): only the first local GPU's compute capability is used —
  # heterogeneous multi-GPU setups would get the wrong capability here.
  d, *_ = jax.local_devices(backend="gpu")
  cuda_options = dict(
      compute_capability=d.compute_capability,
      num_warps=num_warps,
      num_stages=num_stages,
      debug=debug,
  )

  lowering_result = lowering.lower_jaxpr_to_triton_module(
      jaxpr, (*in_shapes, *out_shapes), grid_mapping, name, cuda_options
  )
  module_op = lowering_result.module.operation
  if debug:
    print(module_op.get_asm(enable_debug_info=True, pretty_debug_info=True))

  # Pad the lowered grid out to exactly three dimensions (x, y, z).
  grid_x, grid_y, grid_z = normalize_grid(lowering_result.grid)
  out_types = [
      ir.RankedTensorType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype))
      for shape in out_shapes
  ]
  # Serialize the Triton module to MLIR bytecode so it can travel inside
  # the custom call's backend config.
  buf = io.BytesIO()
  module_op.write_bytecode(buf)
  backend_config = dict(
      name=ir.StringAttr.get(name),
      ir=ir.StringAttr.get(buf.getvalue()),
      num_stages=mlir.i32_attr(num_stages),
      num_warps=mlir.i32_attr(num_warps),
      grid_x=mlir.i32_attr(grid_x),
      grid_y=mlir.i32_attr(grid_y),
      grid_z=mlir.i32_attr(grid_z),
      debug=ir.BoolAttr.get(debug),
  )
  if "serialized_metadata" in (triton_params or {}):
    # This field is unstable and may be removed in the future.
    backend_config["serialized_metadata"] = ir.StringAttr.get(
        triton_params["serialized_metadata"]
    )
  return mlir.custom_call(
      call_target_name="__gpu$xla.gpu.triton",
      result_types=out_types,
      operands=in_nodes,
      backend_config=backend_config,
      # api_version=4 selects the typed-FFI custom-call ABI.
      api_version=4,
      operand_layouts=avals_to_layouts(ctx.avals_in),
      result_layouts=avals_to_layouts(ctx.avals_out),
      operand_output_aliases=dict(input_output_aliases),
  ).results
108
-
109
-
110
43
def pallas_call_lowering (
111
44
ctx : mlir .LoweringRuleContext ,
112
45
* in_nodes ,
@@ -154,17 +87,42 @@ def pallas_call_lowering(
154
87
print (jaxpr )
155
88
print (grid_mapping )
156
89
157
- return _pallas_call_ttir_lowering (
158
- ctx ,
159
- * in_nodes ,
160
- jaxpr = jaxpr ,
161
- name = name ,
162
- in_shapes = in_shapes ,
163
- out_shapes = out_shapes ,
164
- debug = debug ,
165
- input_output_aliases = input_output_aliases ,
166
- grid_mapping = grid_mapping ,
167
- triton_params = triton_params ,
168
- num_warps = num_warps ,
169
- num_stages = num_stages ,
90
+ lowering_result = lowering .lower_jaxpr_to_triton_module (
91
+ jaxpr , (* in_shapes , * out_shapes ), grid_mapping , name ,
92
+ )
93
+ module_op = lowering_result .module .operation
94
+ if debug :
95
+ print (module_op .get_asm (enable_debug_info = True , pretty_debug_info = True ))
96
+
97
+ grid_x , grid_y , grid_z = normalize_grid (lowering_result .grid )
98
+ out_types = [
99
+ ir .RankedTensorType .get (shape .shape , mlir .dtype_to_ir_type (shape .dtype ))
100
+ for shape in out_shapes
101
+ ]
102
+ buf = io .BytesIO ()
103
+ module_op .write_bytecode (buf )
104
+ backend_config = dict (
105
+ name = ir .StringAttr .get (name ),
106
+ ir = ir .StringAttr .get (buf .getvalue ()), # type: ignore
107
+ num_stages = mlir .i32_attr (num_stages ),
108
+ num_warps = mlir .i32_attr (num_warps ),
109
+ grid_x = mlir .i32_attr (grid_x ),
110
+ grid_y = mlir .i32_attr (grid_y ),
111
+ grid_z = mlir .i32_attr (grid_z ),
112
+ debug = ir .BoolAttr .get (debug ),
113
+ )
114
+ if "serialized_metadata" in (triton_params or {}):
115
+ # This field is unstable and may be removed in the future.
116
+ backend_config ["serialized_metadata" ] = ir .StringAttr .get (
117
+ triton_params ["serialized_metadata" ]
170
118
)
119
+ return mlir .custom_call (
120
+ call_target_name = "__gpu$xla.gpu.triton" ,
121
+ result_types = out_types ,
122
+ operands = in_nodes ,
123
+ backend_config = backend_config ,
124
+ api_version = 4 ,
125
+ operand_layouts = avals_to_layouts (ctx .avals_in ),
126
+ result_layouts = avals_to_layouts (ctx .avals_out ),
127
+ operand_output_aliases = dict (input_output_aliases ),
128
+ ).results
0 commit comments