Skip to content

Commit b1aa1d9

Browse files
authored
[Wave] Migrate iree.turbine.runtime -> wave_lang.runtime (#49)
Signed-off-by: Harsh Menon <[email protected]>
1 parent cd0b5d4 commit b1aa1d9

File tree

12 files changed

+34
-53
lines changed

12 files changed

+34
-53
lines changed

docs/core/runtime.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
`iree.turbine.runtime`
1+
`wave_lang.runtime`
22
======================
33

4-
.. automodule:: iree.turbine.runtime
4+
.. automodule:: wave_lang.runtime
55
:imported-members:
66
:members:
77
:undoc-members:
88

99
op_reg
1010
--------------
1111

12-
.. automodule:: iree.turbine.runtime.op_reg
12+
.. automodule:: wave_lang.runtime.op_reg
1313
:imported-members:
1414
:members:
1515
:undoc-members:

iree/turbine/kernel/gen/kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
IrType,
3131
)
3232

33-
from ...runtime.op_reg import (
33+
from wave_lang.runtime.op_reg import (
3434
def_library,
3535
CustomOp,
3636
KernelBuilder,

iree/turbine/kernel/wave/iree_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import torch
88
from wave_lang.support.conversions import TORCH_DTYPE_TO_IREE_TYPE_ASM
9-
from iree.turbine.runtime.launch import Launchable
9+
from wave_lang.runtime.launch import Launchable
1010

1111

1212
def get_chain_mmt_asm(

lit_tests/kernel/wave/sharktank_integration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,12 @@
3535
from iree.turbine.kernel.wave.utils.run_utils import (
3636
set_default_run_config,
3737
)
38-
from iree.turbine.runtime.op_reg.base import (
38+
from wave_lang.runtime.op_reg.base import (
3939
CustomOp,
4040
KernelBuilder,
4141
KernelSelection,
4242
)
43-
from iree.turbine.runtime.op_reg.impl_helper import (
43+
from wave_lang.runtime.op_reg.impl_helper import (
4444
call_function,
4545
)
4646
from wave_lang.transforms.merger import Merger

iree/turbine/runtime/__init__.py renamed to wave_lang/runtime/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,3 @@
77
from .device import *
88
from .invoke import *
99
from .launch import *
10-
from . import op_reg

iree/turbine/runtime/device.py renamed to wave_lang/runtime/device.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,33 +4,30 @@
44
# See https://llvm.org/LICENSE.txt for license information.
55
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

7+
import atexit
8+
import ctypes
9+
import platform
710
from functools import lru_cache
811
from hashlib import sha1
12+
from threading import Lock, local
913
from typing import Any, Callable, Dict, Optional, Union
10-
from threading import local, Lock
11-
12-
import warnings
13-
import platform
14-
import atexit
1514

16-
import ctypes
1715
import torch
1816

1917
from iree.runtime import (
2018
BufferUsage,
19+
ExternalTimepointFlags,
2120
HalBufferView,
2221
HalDevice,
2322
HalDriver,
2423
HalExternalTimepoint,
2524
MemoryType,
25+
SemaphoreCompatibility,
2626
VmInstance,
2727
VmModule,
28-
SemaphoreCompatibility,
29-
ExternalTimepointFlags,
3028
create_hal_module,
3129
get_driver,
3230
)
33-
3431
from wave_lang.support.conversions import (
3532
dtype_to_element_type,
3633
torch_dtype_to_numpy,

iree/turbine/runtime/invoke.py renamed to wave_lang/runtime/invoke.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,17 @@
55
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

77
from typing import (
8-
Any,
98
Callable,
10-
Sequence,
119
)
1210

13-
from .device import Device
14-
1511
from iree.runtime import (
12+
HalFence,
1613
VmContext,
1714
VmFunction,
18-
HalFence,
1915
VmVariantList,
2016
)
2117

18+
from .device import Device
2219

2320
__all__ = [
2421
"invoke_vm_function",
@@ -33,7 +30,7 @@ def invoke_vm_function(
3330
arg_list: VmVariantList,
3431
ret_list: VmVariantList,
3532
*,
36-
timer: Callable[[], float] = (lambda: 0.0)
33+
timer: Callable[[], float] = (lambda: 0.0),
3734
):
3835
"""Invokes a vm function on a device, adding async fences to the arg_list if is_async.
3936

iree/turbine/runtime/launch.py renamed to wave_lang/runtime/launch.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,31 +11,27 @@
1111
from torch import Tensor
1212

1313
from iree.compiler.api import (
14+
Output,
1415
Session,
1516
Source,
16-
Output,
1717
)
18-
1918
from iree.runtime import (
20-
create_io_parameters_module,
2119
HalBufferView,
2220
HalElementType,
23-
HalFence,
2421
ParameterProvider,
2522
VmContext,
2623
VmFunction,
2724
VmModule,
2825
VmRef,
2926
VmVariantList,
27+
create_io_parameters_module,
3028
)
31-
3229
from wave_lang.support.logging import runtime_logger as logger
3330

3431
from .device import (
35-
get_device_from_torch,
3632
Device,
33+
get_device_from_torch,
3734
)
38-
3935
from .invoke import invoke_vm_function
4036

4137
__all__ = [

iree/turbine/runtime/op_reg/__init__.py renamed to wave_lang/runtime/op_reg/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,3 @@
55
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

77
from .base import *
8-
from . import impl_helper

iree/turbine/runtime/op_reg/base.py renamed to wave_lang/runtime/op_reg/base.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,41 +8,38 @@
88
dispatcher.
99
"""
1010

11-
from typing import Any, Callable, List, Optional, Sequence, Type, Union, cast
12-
13-
from abc import ABC, abstractmethod
1411
import functools
1512
import logging
1613
import re
1714
import textwrap
1815
import threading
16+
from abc import ABC, abstractmethod
17+
from typing import Any, Callable, Optional, Sequence, Type, Union, cast
1918

2019
import torch
2120
from torch import Tensor
2221

22+
from wave_lang.support.conversions import (
23+
TORCH_DTYPE_TO_IREE_TYPE_ASM,
24+
)
2325
from wave_lang.support.ir_imports import (
2426
Block,
2527
Context,
2628
FunctionType,
2729
IndexType,
2830
InsertionPoint,
2931
IntegerAttr,
32+
IrType,
3033
Location,
3134
StringAttr,
3235
SymbolTable,
33-
IrType,
3436
Value,
3537
arith_d,
3638
builtin_d,
3739
func_d,
3840
)
39-
4041
from wave_lang.support.logging import runtime_logger as logger
4142

42-
from wave_lang.support.conversions import (
43-
TORCH_DTYPE_TO_IREE_TYPE_ASM,
44-
)
45-
4643
__all__ = [
4744
"ArgDescriptor",
4845
"AttrArg",

0 commit comments

Comments
 (0)