Skip to content

Commit 489febe

Browse files
Google-ML-Automationjax authors
authored andcommitted
Enable input fusion for a specific kernel pattern.
cl/640530524 introduces batching support for some pallas calls that don't currently support it yet using dynamic slicing the input and dynamically updating the output. This CL ensures that XLA-guided input fusion into pallas kernel is working as expected for such pattern. We don't have support for fusion on the output side yet for pallas kernels. PiperOrigin-RevId: 641989012
1 parent f4dfa84 commit 489febe

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

tests/pallas/pallas_call_tpu_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ def kernel(s, x):
311311
grid=8,
312312
),
313313
interpret=self.interpret,
314+
compiler_params=dict(mosaic=dict(allow_input_fusion=[False, True])),
314315
)(s, x)
315316

316317
first = x[0, ...].reshape((1, 8, 8, -1))[:, s[0, ...]].reshape(x.shape[1:])

0 commit comments

Comments
 (0)