Skip to content

Commit 5df0e49

Browse files
oulgen authored and pytorchmergebot committed
[pallas backend] implement complex numbers (pytorch#167947)
Pull Request resolved: pytorch#167947 Approved by: https://github.com/jansel
1 parent e5e94ec commit 5df0e49

File tree

3 files changed

+177
-25
lines changed

3 files changed

+177
-25
lines changed

test/inductor/test_pallas.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,112 @@ def fn(x, row_indices):
458458
expected = fn(x, row_indices)
459459
self.assertEqual(result, expected)
460460

461+
def test_complex64_mul(self):
462+
"""Test complex64 multiplication."""
463+
464+
def fn(a, b):
465+
return a * b
466+
467+
compiled = self._compile(fn)
468+
469+
a = torch.randn(16, dtype=torch.complex64, device=self.DEVICE)
470+
b = torch.randn(16, dtype=torch.complex64, device=self.DEVICE)
471+
result = compiled(a, b)
472+
expected = fn(a, b)
473+
self.assertEqual(result, expected)
474+
475+
def test_complex_conj(self):
476+
"""Test complex conjugate."""
477+
478+
def fn(x):
479+
return torch.conj(x)
480+
481+
compiled = self._compile(fn)
482+
483+
x = torch.randn(16, dtype=torch.complex64, device=self.DEVICE)
484+
result = compiled(x)
485+
expected = fn(x)
486+
self.assertEqual(result, expected)
487+
488+
def test_complex_real(self):
489+
"""Test extracting real part of complex tensor."""
490+
491+
def fn(x):
492+
return torch.real(x)
493+
494+
compiled = self._compile(fn)
495+
496+
x = torch.randn(16, dtype=torch.complex64, device=self.DEVICE)
497+
result = compiled(x)
498+
expected = fn(x)
499+
self.assertEqual(result, expected)
500+
501+
def test_complex_imag(self):
502+
"""Test extracting imaginary part of complex tensor."""
503+
504+
def fn(x):
505+
return torch.imag(x)
506+
507+
compiled = self._compile(fn)
508+
509+
x = torch.randn(16, dtype=torch.complex64, device=self.DEVICE)
510+
result = compiled(x)
511+
expected = fn(x)
512+
self.assertEqual(result, expected)
513+
514+
def test_complex_abs(self):
515+
"""Test complex absolute value (magnitude)."""
516+
517+
def fn(x):
518+
return torch.abs(x)
519+
520+
compiled = self._compile(fn)
521+
522+
x = torch.randn(16, dtype=torch.complex64, device=self.DEVICE)
523+
result = compiled(x)
524+
expected = fn(x)
525+
self.assertEqual(result, expected)
526+
527+
def test_complex128_conj(self):
528+
"""Test complex128 conjugate operation."""
529+
530+
def fn(x):
531+
return torch.conj(x)
532+
533+
compiled = self._compile(fn)
534+
535+
x = torch.randn(16, dtype=torch.complex128, device=self.DEVICE)
536+
result = compiled(x)
537+
expected = fn(x)
538+
self.assertEqual(result, expected)
539+
540+
def test_complex_mul_scalar(self):
541+
"""Test complex multiplication with scalar."""
542+
543+
def fn(x):
544+
return x * 2.5
545+
546+
compiled = self._compile(fn)
547+
548+
x = torch.randn(16, dtype=torch.complex64, device=self.DEVICE)
549+
result = compiled(x)
550+
expected = fn(x)
551+
self.assertEqual(result, expected)
552+
553+
def test_complex_conj_mul(self):
554+
"""Test conjugate followed by multiplication."""
555+
556+
def fn(x, y):
557+
return torch.conj(x) * y
558+
559+
compiled = self._compile(fn)
560+
561+
x = torch.randn(16, dtype=torch.complex64, device=self.DEVICE)
562+
y = torch.randn(16, dtype=torch.complex64, device=self.DEVICE)
563+
result = compiled(x, y)
564+
expected = fn(x, y)
565+
self.assertEqual(result, expected)
566+
461567

462568
@unittest.skipUnless(HAS_PALLAS, "requires jax and pallas")
463569
class PallasTestsCUDA(PallasTestsMixin, TestCase):

torch/_inductor/codegen/pallas.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,32 @@ def constant(val, dtype: torch.dtype) -> str:
218218
return "True" if val else "False"
219219
return f"jnp.array({val}, dtype={jax_dtype})"
220220

221+
@staticmethod
222+
def real(x: str) -> str:
223+
return f"jnp.real({x})"
224+
225+
@staticmethod
226+
def imag(x: str) -> str:
227+
return f"jnp.imag({x})"
228+
229+
@staticmethod
230+
def conj(x: str) -> str:
231+
return f"jnp.conj({x})"
232+
233+
@staticmethod
234+
def angle(x: str) -> str:
235+
return f"jnp.angle({x})"
236+
237+
@staticmethod
238+
def view_as_real(x: str) -> str:
239+
"""View complex tensor as real tensor with extra dimension."""
240+
return f"jnp.stack([jnp.real({x}), jnp.imag({x})], axis=-1)"
241+
242+
@staticmethod
243+
def view_as_complex(x: str) -> str:
244+
"""View real tensor as complex tensor."""
245+
return f"({x}[..., 0] + 1j * {x}[..., 1])"
246+
221247

222248
class PallasKernel(SIMDKernel):
223249
"""
@@ -642,6 +668,7 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[ove
642668
import jax
643669
import jax.numpy as jnp
644670
from jax.experimental import pallas as pl
671+
from torch._inductor.runtime.runtime_utils import torch_dtype_to_jax_runtime
645672
"""
646673
+ (
647674
"\n from jax.experimental.pallas import triton as pltriton"
@@ -822,16 +849,6 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[ove
822849
)
823850

824851
code.writeline("# Prepare output metadata from PyTorch tensor")
825-
code.writeline("# Map PyTorch dtype to JAX dtype")
826-
code.writeline("_torch_dtype_to_jax = {")
827-
code.writeline(
828-
" torch.float32: jnp.float32, torch.float64: jnp.float64, torch.float16: jnp.float16,"
829-
)
830-
code.writeline(
831-
" torch.int32: jnp.int32, torch.int64: jnp.int64, torch.int16: jnp.int16, torch.int8: jnp.int8,"
832-
)
833-
code.writeline(" torch.uint8: jnp.uint8, torch.bool: jnp.bool_,")
834-
code.writeline("}")
835852
code.writeline(
836853
"out_shapes = ("
837854
+ ", ".join([f"tuple({name}.shape)" for name in output_params])
@@ -840,7 +857,10 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[ove
840857
code.writeline(
841858
"out_dtypes = ("
842859
+ ", ".join(
843-
[f"_torch_dtype_to_jax[{name}.dtype]" for name in output_params]
860+
[
861+
f"torch_dtype_to_jax_runtime({name}.dtype)"
862+
for name in output_params
863+
]
844864
)
845865
+ ",)"
846866
)

torch/_inductor/runtime/runtime_utils.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,43 @@ def compile_mps_shader(source: str) -> Any:
189189
raise SyntaxError(f"failed to compile {source} with {err.msg}") from err
190190

191191

192+
def torch_dtype_to_jax_runtime(dtype: torch.dtype) -> Any:
193+
"""
194+
Map PyTorch dtype to actual JAX dtype object at runtime.
195+
196+
This helper is used in generated Pallas kernels at runtime to convert
197+
PyTorch dtypes to JAX dtype objects (not string representations).
198+
199+
Args:
200+
dtype: PyTorch dtype to convert
201+
202+
Returns:
203+
JAX dtype object (e.g., jnp.float32 object itself)
204+
"""
205+
import jax.numpy as jnp # pyrefly: ignore [import-error]
206+
207+
dtype_map = {
208+
torch.float32: jnp.float32,
209+
torch.float64: jnp.float64,
210+
torch.float16: jnp.float16,
211+
torch.bfloat16: jnp.bfloat16,
212+
torch.int32: jnp.int32,
213+
torch.int64: jnp.int64,
214+
torch.int16: jnp.int16,
215+
torch.int8: jnp.int8,
216+
torch.uint8: jnp.uint8,
217+
torch.bool: jnp.bool_,
218+
torch.complex64: jnp.complex64,
219+
torch.complex128: jnp.complex128,
220+
}
221+
if dtype not in dtype_map:
222+
raise ValueError(f"Unsupported dtype for JAX conversion: {dtype}")
223+
return dtype_map[dtype]
224+
225+
192226
def torch_dtype_to_jax(dtype: torch.dtype) -> str:
193227
"""
194-
Map PyTorch dtype to JAX dtype expression.
228+
Map PyTorch dtype to JAX dtype expression string.
195229
196230
This helper is used at compile time in codegen to generate
197231
JAX dtype expressions for Pallas kernels.
@@ -202,16 +236,8 @@ def torch_dtype_to_jax(dtype: torch.dtype) -> str:
202236
Returns:
203237
JAX dtype expression as string (e.g., "jnp.float32")
204238
"""
205-
dtype_map = {
206-
torch.float32: "jnp.float32",
207-
torch.float64: "jnp.float64",
208-
torch.float16: "jnp.float16",
209-
torch.bfloat16: "jnp.bfloat16",
210-
torch.int32: "jnp.int32",
211-
torch.int64: "jnp.int64",
212-
torch.int16: "jnp.int16",
213-
torch.int8: "jnp.int8",
214-
torch.uint8: "jnp.uint8",
215-
torch.bool: "jnp.bool_",
216-
}
217-
return dtype_map.get(dtype, f"jnp.{dtype}")
239+
jax_dtype = torch_dtype_to_jax_runtime(dtype)
240+
dtype_name = jax_dtype.__name__
241+
if dtype_name == "bool":
242+
dtype_name = "bool_"
243+
return f"jnp.{dtype_name}"

0 commit comments

Comments (0)