Skip to content
This repository was archived by the owner on Oct 19, 2024. It is now read-only.

Commit 4e92039

Browse files
authored
Add get mhlo module method in WarppedHlo (#925)
1 parent 28cb518 commit 4e92039

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

alpa/wrapped_hlo.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Union
66

77
from jax._src.lib import xla_extension as xe
8+
from jax.interpreters import mlir
89

910

1011
class HloStatus(Enum):
@@ -38,6 +39,13 @@ def __init__(self,
3839
def get_computation(self) -> xe.XlaComputation:
3940
return xe.XlaComputation(self.module.as_serialized_hlo_module_proto())
4041

42+
def get_mhlo(self):
43+
xla_computation = self.get_computation()
44+
module_str = xe.mlir.xla_computation_to_mlir_module(xla_computation)
45+
with mlir.make_ir_context():
46+
mhlo = mlir.ir.Module.parse(module_str)
47+
return mhlo
48+
4149
def get_module(self) -> xe.HloModule:
4250
return self.module
4351

0 commit comments

Comments
 (0)