Commit 9ff1922

oulgen authored and pytorchmergebot committed
[pallas backend] implement more ops (pytorch#167951)
Pull Request resolved: pytorch#167951
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#167947
1 parent 5df0e49 commit 9ff1922

2 files changed, +283 -0 lines changed


test/inductor/test_pallas.py

Lines changed: 126 additions & 0 deletions
@@ -564,6 +564,132 @@ def fn(x, y):
         expected = fn(x, y)
         self.assertEqual(result, expected)

+    def test_where(self):
+        """Test torch.where operation."""
+
+        def fn(x, y):
+            return torch.where(x > 0, x, y)
+
+        compiled = self._compile(fn)
+
+        x = torch.randn(16, device=self.DEVICE)
+        y = torch.randn(16, device=self.DEVICE)
+        result = compiled(x, y)
+        expected = fn(x, y)
+        self.assertEqual(result, expected)
+
+    def test_clamp(self):
+        """Test torch.clamp operation."""
+
+        def fn(x):
+            return torch.clamp(x, -1.0, 1.0)
+
+        compiled = self._compile(fn)
+
+        x = torch.randn(16, device=self.DEVICE) * 2
+        result = compiled(x)
+        expected = fn(x)
+        self.assertEqual(result, expected)
+
+    def test_comparison_ops(self):
+        """Test comparison operations."""
+
+        def fn(a, b):
+            gt = a > b
+            lt = a < b
+            eq = a == b
+            return gt.float() + lt.float() + eq.float()
+
+        compiled = self._compile(fn)
+
+        a = torch.randn(16, device=self.DEVICE)
+        b = torch.randn(16, device=self.DEVICE)
+        result = compiled(a, b)
+        expected = fn(a, b)
+        self.assertEqual(result, expected)
+
+    def test_logical_ops(self):
+        """Test logical operations."""
+
+        def fn(a, b):
+            return torch.logical_and(a > 0, b > 0).float()
+
+        compiled = self._compile(fn)
+
+        a = torch.randn(16, device=self.DEVICE)
+        b = torch.randn(16, device=self.DEVICE)
+        result = compiled(a, b)
+        expected = fn(a, b)
+        self.assertEqual(result, expected)
+
+    def test_sign(self):
+        """Test sign operation."""
+
+        def fn(x):
+            return torch.sign(x)
+
+        compiled = self._compile(fn)
+
+        x = torch.randn(16, device=self.DEVICE)
+        result = compiled(x)
+        expected = fn(x)
+        self.assertEqual(result, expected)
+
+    def test_reciprocal(self):
+        """Test reciprocal operation."""
+
+        def fn(x):
+            return torch.reciprocal(x)
+
+        compiled = self._compile(fn)
+
+        x = torch.randn(16, device=self.DEVICE) + 1.0  # Avoid zeros
+        result = compiled(x)
+        expected = fn(x)
+        self.assertEqual(result, expected)
+
+    def test_square(self):
+        """Test square operation."""
+
+        def fn(x):
+            return torch.square(x)
+
+        compiled = self._compile(fn)
+
+        x = torch.randn(16, device=self.DEVICE)
+        result = compiled(x)
+        expected = fn(x)
+        self.assertEqual(result, expected)
+
+    def test_erf(self):
+        """Test erf operation."""
+        if self.DEVICE == "cuda":
+            self.skipTest("erf not supported in Pallas GPU (Triton) backend")
+
+        def fn(x):
+            return torch.erf(x)
+
+        compiled = self._compile(fn)
+
+        x = torch.randn(16, device=self.DEVICE)
+        result = compiled(x)
+        expected = fn(x)
+        self.assertEqual(result, expected)
+
+    def test_atan2(self):
+        """Test atan2 operation."""
+
+        def fn(a, b):
+            return torch.atan2(a, b)
+
+        compiled = self._compile(fn)
+
+        a = torch.randn(16, device=self.DEVICE)
+        b = torch.randn(16, device=self.DEVICE)
+        result = compiled(a, b)
+        expected = fn(a, b)
+        self.assertEqual(result, expected)
+

 @unittest.skipUnless(HAS_PALLAS, "requires jax and pallas")
 class PallasTestsCUDA(PallasTestsMixin, TestCase):
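Every new test follows the same eager-vs-compiled comparison pattern. A standalone sketch of that pattern, for illustration only: it uses plain torch.compile with the default Inductor backend and torch.testing.assert_close, whereas the tests above go through the mixin's _compile helper, DEVICE attribute, and self.assertEqual.

    import torch


    def fn(x, y):
        return torch.where(x > 0, x, y)


    # Eager execution is the reference; the compiled output must match it.
    compiled = torch.compile(fn)  # default Inductor backend, illustrative only
    x = torch.randn(16)
    y = torch.randn(16)
    torch.testing.assert_close(compiled(x, y), fn(x, y))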

torch/_inductor/codegen/pallas.py

Lines changed: 157 additions & 0 deletions
@@ -244,6 +244,163 @@ def view_as_complex(x: str) -> str:
         """View real tensor as complex tensor."""
         return f"({x}[..., 0] + 1j * {x}[..., 1])"

+    # Comparison operations
+    @staticmethod
+    def eq(a: str, b: str) -> str:
+        return f"({a} == {b})"
+
+    @staticmethod
+    def ne(a: str, b: str) -> str:
+        return f"({a} != {b})"
+
+    @staticmethod
+    def lt(a: str, b: str) -> str:
+        return f"({a} < {b})"
+
+    @staticmethod
+    def le(a: str, b: str) -> str:
+        return f"({a} <= {b})"
+
+    @staticmethod
+    def gt(a: str, b: str) -> str:
+        return f"({a} > {b})"
+
+    @staticmethod
+    def ge(a: str, b: str) -> str:
+        return f"({a} >= {b})"
+
+    # Logical operations
+    @staticmethod
+    def logical_and(a: str, b: str) -> str:
+        return f"jnp.logical_and({a}, {b})"
+
+    @staticmethod
+    def logical_or(a: str, b: str) -> str:
+        return f"jnp.logical_or({a}, {b})"
+
+    @staticmethod
+    def logical_not(x: str) -> str:
+        return f"jnp.logical_not({x})"
+
+    @staticmethod
+    def logical_xor(a: str, b: str) -> str:
+        return f"jnp.logical_xor({a}, {b})"
+
+    # Math operations
+    @staticmethod
+    def atan2(a: str, b: str) -> str:
+        return f"jnp.arctan2({a}, {b})"
+
+    @staticmethod
+    def hypot(a: str, b: str) -> str:
+        return f"jnp.hypot({a}, {b})"
+
+    @staticmethod
+    def fmod(a: str, b: str) -> str:
+        return f"jnp.fmod({a}, {b})"
+
+    @staticmethod
+    def remainder(a: str, b: str) -> str:
+        return f"jnp.remainder({a}, {b})"
+
+    @staticmethod
+    def clamp(x: str, min_val: str, max_val: str) -> str:
+        return f"jnp.clip({x}, {min_val}, {max_val})"
+
+    @staticmethod
+    def clip(x: str, min_val: str, max_val: str) -> str:
+        return f"jnp.clip({x}, {min_val}, {max_val})"
+
+    # Sign operations
+    @staticmethod
+    def sign(x: str) -> str:
+        return f"jnp.sign({x})"
+
+    @staticmethod
+    def signbit(x: str) -> str:
+        return f"jnp.signbit({x})"
+
+    # Special math functions
+    @staticmethod
+    def erf(x: str) -> str:
+        return f"jax.scipy.special.erf({x})"
+
+    @staticmethod
+    def erfc(x: str) -> str:
+        return f"jax.scipy.special.erfc({x})"
+
+    @staticmethod
+    def erfinv(x: str) -> str:
+        return f"jax.scipy.special.erfinv({x})"
+
+    @staticmethod
+    def lgamma(x: str) -> str:
+        return f"jax.scipy.special.gammaln({x})"
+
+    @staticmethod
+    def digamma(x: str) -> str:
+        return f"jax.scipy.special.digamma({x})"
+
+    # Reciprocal and square
+    @staticmethod
+    def reciprocal(x: str) -> str:
+        return f"jnp.reciprocal({x})"
+
+    @staticmethod
+    def square(x: str) -> str:
+        return f"jnp.square({x})"
+
+    # Additional operations
+    @staticmethod
+    def fma(a: str, b: str, c: str) -> str:
+        """Fused multiply-add: a * b + c"""
+        return f"jnp.fma({a}, {b}, {c})"
+
+    @staticmethod
+    def copysign(a: str, b: str) -> str:
+        return f"jnp.copysign({a}, {b})"
+
+    @staticmethod
+    def nextafter(a: str, b: str) -> str:
+        return f"jnp.nextafter({a}, {b})"
+
+    @staticmethod
+    def ldexp(a: str, b: str) -> str:
+        return f"jnp.ldexp({a}, {b})"
+
+    @staticmethod
+    def frexp(x: str) -> str:
+        return f"jnp.frexp({x})"
+
+    @staticmethod
+    def modf(x: str) -> str:
+        return f"jnp.modf({x})"
+
+    # Bitwise operations
+    @staticmethod
+    def bitwise_and(a: str, b: str) -> str:
+        return f"jnp.bitwise_and({a}, {b})"
+
+    @staticmethod
+    def bitwise_or(a: str, b: str) -> str:
+        return f"jnp.bitwise_or({a}, {b})"
+
+    @staticmethod
+    def bitwise_xor(a: str, b: str) -> str:
+        return f"jnp.bitwise_xor({a}, {b})"
+
+    @staticmethod
+    def bitwise_not(x: str) -> str:
+        return f"jnp.bitwise_not({x})"
+
+    @staticmethod
+    def left_shift(a: str, b: str) -> str:
+        return f"jnp.left_shift({a}, {b})"
+
+    @staticmethod
+    def right_shift(a: str, b: str) -> str:
+        return f"jnp.right_shift({a}, {b})"
+

 class PallasKernel(SIMDKernel):
     """
