Skip to content

Commit 542bb44

Browse files
authored
Replace CRLF with LF in wave_conv_utils.py (#69)
For some reason, this file (and only this file) had been checked into git in CRLF format, which is not good practice.
1 parent a597b1f commit 542bb44

File tree

1 file changed

+116
-116
lines changed

1 file changed

+116
-116
lines changed
Lines changed: 116 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -1,116 +1,116 @@
1-
from ..utils import *
2-
from dataclasses import dataclass
3-
from pathlib import Path
4-
from typing import Optional
5-
from .conv_utils import ConvConfig
6-
import traceback
7-
8-
try:
9-
import iree.turbine.kernel as tk
10-
import iree.turbine.kernel.lang as tkl
11-
from iree.turbine.kernel.wave.templates.conv import get_igemm_conv2d
12-
from iree.turbine.kernel.wave.compile import wave_compile, WaveCompileOptions
13-
from iree.turbine.kernel.wave.scheduling.schedule_enums import SchedulingType
14-
from iree.turbine.kernel.wave.utils.torch_utils import (
15-
device_randn,
16-
device_randint,
17-
device_randperm,
18-
device_zeros,
19-
)
20-
except ImportError as e:
21-
TURBINE_AVAILABLE = False
22-
turbine_import_error = e
23-
else:
24-
TURBINE_AVAILABLE = True
25-
26-
27-
def compile_wave_conv_config(
28-
tag: str,
29-
config: ConvConfig,
30-
kernel_dir: Path,
31-
vmfb_dir: Path,
32-
extra_compiler_args: list[str],
33-
) -> tuple[Path, Optional[Path]]:
34-
if not TURBINE_AVAILABLE:
35-
raise ValueError(
36-
f"Can't compile TK benchmark because of a failed import (most likely iree.turbine is missing): {turbine_import_error}"
37-
)
38-
39-
# Name with tag is used for filenames so that duplicate configs with
40-
# different tags will not clobber eachother.
41-
name_with_tag = tag + "-" + config.get_name()
42-
mlir_file = kernel_dir / (name_with_tag + ".mlir")
43-
vmfb_file = vmfb_dir / (name_with_tag + ".vmfb")
44-
files_path = vmfb_dir / name_with_tag
45-
46-
try:
47-
_compile_conv(config, mlir_file, vmfb_file)
48-
except Exception as e:
49-
error_file = vmfb_dir / (config.get_name() + "_error.txt")
50-
print(f"Failed to compile {config.get_name()}. Error dumped in {error_file}")
51-
with open(error_file, "w") as f:
52-
f.write(str(e))
53-
f.write(traceback.format_exc())
54-
return mlir_file, None, None
55-
56-
return mlir_file, vmfb_file, files_path
57-
58-
59-
def _decode_op(op: str) -> tuple[str, str]:
60-
if op.startswith("conv_2d_"):
61-
return "conv_2d", op[len("conv_2d_") :]
62-
63-
raise ValueError(f"Unsupported op: {op}")
64-
65-
66-
def _convert_dtype(dtype: str):
67-
dtypes = {
68-
"i8": tkl.i8,
69-
"i16": tkl.i16,
70-
"i32": tkl.i32,
71-
"i64": tkl.i64,
72-
"f16": tkl.f16,
73-
"f32": tkl.f32,
74-
"f64": tkl.f64,
75-
"bf16": tkl.bf16,
76-
}
77-
return dtypes[dtype]
78-
79-
80-
def _compile_conv(config: ConvConfig, mlir_file: Path, vmfb_file: Path):
81-
print("Compile TKW kernel", config.OP)
82-
op_type, layout = _decode_op(config.OP)
83-
84-
in_h = config.H * config.S + config.P - 1
85-
in_w = config.W * config.S + config.Q - 1
86-
if op_type == "conv_2d":
87-
conv, hyperparams = get_igemm_conv2d(
88-
layout=layout,
89-
n=config.N,
90-
h=in_h,
91-
w=in_w,
92-
c=config.C,
93-
hf=config.P,
94-
wf=config.Q,
95-
nf=config.F,
96-
stride=config.S,
97-
input_dtype=_convert_dtype(config.input_dtype),
98-
output_dtype=_convert_dtype(config.output_dtype),
99-
)
100-
else:
101-
raise ValueError(f"Unsupported op_type: {op_type}")
102-
103-
options = WaveCompileOptions(
104-
subs=hyperparams,
105-
canonicalize=True,
106-
create_vmfb_file=vmfb_file,
107-
schedule=SchedulingType.NONE,
108-
# inline=False, (TODO: how to do this with new API?)
109-
backend="rocm",
110-
target="gfx942",
111-
)
112-
result = wave_compile(options, conv)
113-
with open(mlir_file, "w") as f:
114-
f.write(result.asm)
115-
116-
print(f"Successfully compiled to {vmfb_file}")
1+
from ..utils import *
2+
from dataclasses import dataclass
3+
from pathlib import Path
4+
from typing import Optional
5+
from .conv_utils import ConvConfig
6+
import traceback
7+
8+
try:
9+
import iree.turbine.kernel as tk
10+
import iree.turbine.kernel.lang as tkl
11+
from iree.turbine.kernel.wave.templates.conv import get_igemm_conv2d
12+
from iree.turbine.kernel.wave.compile import wave_compile, WaveCompileOptions
13+
from iree.turbine.kernel.wave.scheduling.schedule_enums import SchedulingType
14+
from iree.turbine.kernel.wave.utils.torch_utils import (
15+
device_randn,
16+
device_randint,
17+
device_randperm,
18+
device_zeros,
19+
)
20+
except ImportError as e:
21+
TURBINE_AVAILABLE = False
22+
turbine_import_error = e
23+
else:
24+
TURBINE_AVAILABLE = True
25+
26+
27+
def compile_wave_conv_config(
28+
tag: str,
29+
config: ConvConfig,
30+
kernel_dir: Path,
31+
vmfb_dir: Path,
32+
extra_compiler_args: list[str],
33+
) -> tuple[Path, Optional[Path]]:
34+
if not TURBINE_AVAILABLE:
35+
raise ValueError(
36+
f"Can't compile TK benchmark because of a failed import (most likely iree.turbine is missing): {turbine_import_error}"
37+
)
38+
39+
# Name with tag is used for filenames so that duplicate configs with
40+
# different tags will not clobber eachother.
41+
name_with_tag = tag + "-" + config.get_name()
42+
mlir_file = kernel_dir / (name_with_tag + ".mlir")
43+
vmfb_file = vmfb_dir / (name_with_tag + ".vmfb")
44+
files_path = vmfb_dir / name_with_tag
45+
46+
try:
47+
_compile_conv(config, mlir_file, vmfb_file)
48+
except Exception as e:
49+
error_file = vmfb_dir / (config.get_name() + "_error.txt")
50+
print(f"Failed to compile {config.get_name()}. Error dumped in {error_file}")
51+
with open(error_file, "w") as f:
52+
f.write(str(e))
53+
f.write(traceback.format_exc())
54+
return mlir_file, None, None
55+
56+
return mlir_file, vmfb_file, files_path
57+
58+
59+
def _decode_op(op: str) -> tuple[str, str]:
60+
if op.startswith("conv_2d_"):
61+
return "conv_2d", op[len("conv_2d_") :]
62+
63+
raise ValueError(f"Unsupported op: {op}")
64+
65+
66+
def _convert_dtype(dtype: str):
67+
dtypes = {
68+
"i8": tkl.i8,
69+
"i16": tkl.i16,
70+
"i32": tkl.i32,
71+
"i64": tkl.i64,
72+
"f16": tkl.f16,
73+
"f32": tkl.f32,
74+
"f64": tkl.f64,
75+
"bf16": tkl.bf16,
76+
}
77+
return dtypes[dtype]
78+
79+
80+
def _compile_conv(config: ConvConfig, mlir_file: Path, vmfb_file: Path):
81+
print("Compile TKW kernel", config.OP)
82+
op_type, layout = _decode_op(config.OP)
83+
84+
in_h = config.H * config.S + config.P - 1
85+
in_w = config.W * config.S + config.Q - 1
86+
if op_type == "conv_2d":
87+
conv, hyperparams = get_igemm_conv2d(
88+
layout=layout,
89+
n=config.N,
90+
h=in_h,
91+
w=in_w,
92+
c=config.C,
93+
hf=config.P,
94+
wf=config.Q,
95+
nf=config.F,
96+
stride=config.S,
97+
input_dtype=_convert_dtype(config.input_dtype),
98+
output_dtype=_convert_dtype(config.output_dtype),
99+
)
100+
else:
101+
raise ValueError(f"Unsupported op_type: {op_type}")
102+
103+
options = WaveCompileOptions(
104+
subs=hyperparams,
105+
canonicalize=True,
106+
create_vmfb_file=vmfb_file,
107+
schedule=SchedulingType.NONE,
108+
# inline=False, (TODO: how to do this with new API?)
109+
backend="rocm",
110+
target="gfx942",
111+
)
112+
result = wave_compile(options, conv)
113+
with open(mlir_file, "w") as f:
114+
f.write(result.asm)
115+
116+
print(f"Successfully compiled to {vmfb_file}")

0 commit comments

Comments
 (0)