Skip to content

Commit 1621801

Browse files
author
morelos
committed
Update on "[ET-VK] double, short, and uint16 dtype runtime support"
Creating support for double, short, and uint16 for quantization ops. Registering the short keyword since theres already support. Also changing the cpu implementation to support half Differential Revision: [D75959063](https://our.internmc.facebook.com/intern/diff/D75959063/) [ghstack-poisoned]
2 parents 80ecb39 + 2825849 commit 1621801

25 files changed

+220
-93
lines changed

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 82 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -59,64 +59,94 @@
5959
"double": "image3D",
6060
"float": "image3D",
6161
"half": "image3D",
62-
"int": "iimage3D",
63-
"uint": "uimage3D",
62+
# integer dtypes
6463
"int8": "iimage3D",
6564
"uint8": "uimage3D",
66-
"bool": "uimage3D",
67-
"short": "iimage3D",
65+
"int16": "iimage3D",
6866
"uint16": "uimage3D",
67+
"int32": "iimage3D",
68+
"uint32": "uimage3D",
69+
"int64": "iimage3D",
70+
"uint64": "uimage3D",
71+
# common dtype aliases
72+
"bool": "uimage3D",
73+
"int": "iimage3D",
74+
"uint": "uimage3D",
6975
},
7076
2: {
7177
"double": "image2D",
7278
"float": "image2D",
7379
"half": "image2D",
74-
"int": "iimage2D",
75-
"uint": "uimage2D",
80+
# integer dtypes
7681
"int8": "iimage2D",
7782
"uint8": "uimage2D",
78-
"bool": "uimage2D",
79-
"short": "iimage2D",
83+
"int16": "iimage2D",
8084
"uint16": "uimage2D",
85+
"int32": "iimage2D",
86+
"uint32": "uimage2D",
87+
"int64": "iimage2D",
88+
"uint64": "uimage2D",
89+
# common dtype aliases
90+
"bool": "uimage2D",
91+
"int": "iimage2D",
92+
"uint": "uimage2D",
8193
},
8294
},
8395
"SAMPLER_T": {
8496
3: {
8597
"double": "sampler3D",
8698
"float": "sampler3D",
8799
"half": "sampler3D",
88-
"int": "isampler3D",
89-
"uint": "usampler3D",
100+
# integer dtypes
90101
"int8": "isampler3D",
91102
"uint8": "usampler3D",
92-
"bool": "usampler3D",
93-
"short": "isampler3D",
103+
"int16": "isampler3D",
94104
"uint16": "usampler3D",
105+
"int32": "isampler3D",
106+
"uint32": "usampler3D",
107+
"int64": "isampler3D",
108+
"uint64": "usampler3D",
109+
# common dtype aliases
110+
"bool": "usampler3D",
111+
"int": "isampler3D",
112+
"uint": "usampler3D",
95113
},
96114
2: {
97115
"double": "sampler2D",
98116
"float": "sampler2D",
99117
"half": "sampler2D",
100-
"int": "isampler2D",
101-
"uint": "usampler2D",
118+
# integer dtypes
102119
"int8": "isampler2D",
103120
"uint8": "usampler2D",
104-
"bool": "usampler2D",
105-
"short": "isampler2D",
121+
"int16": "isampler2D",
106122
"uint16": "usampler2D",
123+
"int32": "isampler2D",
124+
"uint32": "usampler2D",
125+
"int64": "isampler2D",
126+
"uint64": "usampler2D",
127+
# common dtype aliases
128+
"bool": "usampler2D",
129+
"int": "isampler2D",
130+
"uint": "usampler2D",
107131
},
108132
},
109133
"IMAGE_FORMAT": {
110-
"double": "rgba64f",
134+
"double": "rgba32f",
111135
"float": "rgba32f",
112136
"half": "rgba16f",
113-
"int": "rgba32i",
114-
"uint": "rgba32ui",
137+
# integer dtypes
115138
"int8": "rgba8i",
116139
"uint8": "rgba8ui",
117-
"bool": "rgba8ui",
118-
"short": "rgba16i",
140+
"int16": "rgba16i",
119141
"uint16": "rgba16ui",
142+
"int32": "rgba32i",
143+
"uint32": "rgba32ui",
144+
"int64": "rgba32i",
145+
"uint64": "rgba32ui",
146+
# common dtype aliases
147+
"bool": "rgba8ui",
148+
"int": "rgba32i",
149+
"uint": "rgba32ui",
120150
},
121151
}
122152

@@ -137,10 +167,12 @@ def buffer_scalar_type(dtype: str) -> str:
137167
return "float"
138168
elif dtype == "double":
139169
return "float64_t"
140-
elif dtype == "short":
141-
return "int16_t"
170+
# integer dtype alias conversion
142171
elif dtype == "bool":
143172
return "uint8_t"
173+
# we don't want to append _t for int32 or uint32 as int is already 32bit
174+
elif dtype == "int32" or dtype == "uint32":
175+
return "int" if dtype == "int32" else "uint"
144176
elif dtype[-1].isdigit():
145177
return dtype + "_t"
146178
return dtype
@@ -150,26 +182,29 @@ def buffer_gvec_type(dtype: str, n: int) -> str:
150182
if n == 1:
151183
return buffer_scalar_type(dtype)
152184

153-
if dtype == "float":
154-
return f"vec{n}"
155-
if dtype == "uint":
156-
return f"uvec{n}"
157-
elif dtype == "half":
185+
if dtype == "half":
158186
return f"f16vec{n}"
187+
elif dtype == "float":
188+
return f"vec{n}"
159189
elif dtype == "double":
160-
return f"dvec{n}"
161-
elif dtype == "int":
162-
return f"ivec{n}"
163-
elif dtype == "short":
164-
return f"i16vec{n}"
165-
elif dtype == "uint16":
166-
return f"u16vec{n}"
190+
return f"f64vec{n}"
191+
# integer dtype
167192
elif dtype == "int8":
168193
return f"i8vec{n}"
169194
elif dtype == "uint8":
170195
return f"u8vec{n}"
171-
elif dtype == "bool":
172-
return f"u8vec{n}"
196+
elif dtype == "int16":
197+
return f"i16vec{n}"
198+
elif dtype == "uint16":
199+
return f"u16vec{n}"
200+
elif dtype == "int32" or dtype == "int":
201+
return f"ivec{n}"
202+
elif dtype == "uint32" or dtype == "uint":
203+
return f"uvec{n}"
204+
elif dtype == "int64":
205+
return f"i64vec{n}"
206+
elif dtype == "uint64":
207+
return f"u64vec{n}"
173208

174209
raise AssertionError(f"Invalid dtype: {dtype}")
175210

@@ -387,22 +422,19 @@ def define_required_extensions(dtypes: Union[str, List[str]]):
387422
dtype_list = dtypes if isinstance(dtypes, list) else [dtypes]
388423

389424
for dtype in dtype_list:
390-
nbit = None
391425
glsl_type = None
392426
if dtype == "half":
393-
nbit = "16bit"
394427
glsl_type = "float16"
395-
elif dtype == "short" or dtype == "int16" or dtype == "uint16":
396-
nbit = "16bit"
397-
glsl_type = "int16"
398-
elif dtype == "bool" or dtype == "int8" or dtype == "uint8":
399-
nbit = "8bit"
428+
elif dtype == "double":
429+
glsl_type = "float64"
430+
elif dtype in ["int8", "uint8"]:
400431
glsl_type = "int8"
401-
elif dtype == "double" or dtype == "float64":
402-
out_str += "#extension GL_ARB_gpu_shader_fp64 : require\n"
432+
elif dtype in ["int16", "uint16"]:
433+
glsl_type = "int16"
434+
elif dtype in ["int64", "uint64"]:
435+
glsl_type = "int64"
403436

404-
if nbit is not None and glsl_type is not None:
405-
out_str += f"#extension GL_EXT_shader_{nbit}_storage : require\n"
437+
if glsl_type is not None:
406438
out_str += f"#extension GL_EXT_shader_explicit_arithmetic_types_{glsl_type} : require\n"
407439

408440
return out_str
@@ -658,6 +690,8 @@ def generateVariantCombinations(
658690

659691
elif "VALUE" in value:
660692
suffix = value.get("SUFFIX", value["VALUE"])
693+
if value["VALUE"] in ["int", "uint"]:
694+
raise ValueError(f"Use int32 or uint32 instead of {value['VALUE']}")
661695
param_values.append((param_name, suffix, value["VALUE"]))
662696

663697
else:

backends/vulkan/runtime/graph/ops/glsl/arange.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,6 @@ arange:
1414
DTYPE:
1515
- VALUE: half
1616
- VALUE: float
17-
- VALUE: int
17+
- VALUE: int32
1818
shader_variants:
1919
- NAME: arange

backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@ avg_pool2d:
1313
DTYPE:
1414
- VALUE: half
1515
- VALUE: float
16-
- VALUE: int
16+
- VALUE: int32
1717
shader_variants:
1818
- NAME: avg_pool2d

backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ binary_op:
1717
DTYPE:
1818
- VALUE: half
1919
- VALUE: float
20-
- VALUE: int
20+
- VALUE: int32
2121
shader_variants:
2222
- NAME: binary_add
2323
- NAME: binary_sub

backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ buffer_to_buffer:
1212
DTYPE:
1313
- VALUE: half
1414
- VALUE: float
15-
- VALUE: int
15+
- VALUE: double
1616
- VALUE: int8
1717
- VALUE: uint8
18-
- VALUE: short
19-
- VALUE: uint16
18+
- VALUE: int32
19+
- VALUE: int64
2020
shader_variants:
2121
- NAME: buffer_to_buffer

backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ buffer_to_nchw:
1212
DTYPE:
1313
- VALUE: half
1414
- VALUE: float
15-
- VALUE: int
15+
- VALUE: double
1616
- VALUE: int8
1717
- VALUE: uint8
18-
- VALUE: short
19-
- VALUE: uint16
18+
- VALUE: int32
19+
- VALUE: int64
2020
shader_variants:
2121
- NAME: buffer_to_nchw

backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ copy_channel_offset:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
shader_variants:
1212
- NAME: copy_channel_offset

backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ copy_offset:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
- VALUE: int8
1212
- VALUE: uint8
1313
STORAGE:

backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ copy_packed_dim_offset:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
shader_variants:
1212
- NAME: copy_packed_dim_offset

backends/vulkan/runtime/graph/ops/glsl/embedding.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ embedding:
77
DTYPE:
88
- VALUE: half
99
- VALUE: float
10-
- VALUE: int
10+
- VALUE: int32
1111
shader_variants:
1212
- NAME: embedding

0 commit comments

Comments
 (0)