Skip to content

Commit 52f8e6f

Browse files
author
morelos
committed
Update base for Update on "[ET-VK][Ops] dequantize_per_tensor.default test setup"
Creating dequantize_per_tensor testing framework along with a reference implementation for testing Differential Revision: [D76267054](https://our.internmc.facebook.com/intern/diff/D76267054/) [ghstack-poisoned]
1 parent 52dfa85 commit 52f8e6f

File tree

1 file changed

+22
-29
lines changed

1 file changed

+22
-29
lines changed

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,6 @@ def buffer_scalar_type(dtype: str) -> str:
173173
# we don't want to append _t for int32 or uint32 as int is already 32bit
174174
elif dtype == "int32" or dtype == "uint32":
175175
return "int" if dtype == "int32" else "uint"
176-
elif dtype == "int64" or dtype == "uint64":
177-
return "int" if dtype == "int64" else "uint"
178176
elif dtype[-1].isdigit():
179177
return dtype + "_t"
180178
return dtype
@@ -184,33 +182,28 @@ def buffer_gvec_type(dtype: str, n: int) -> str:
184182
if n == 1:
185183
return buffer_scalar_type(dtype)
186184

187-
if dtype == "half":
188-
return f"f16vec{n}"
189-
elif dtype == "float":
190-
return f"vec{n}"
191-
elif dtype == "double":
192-
return f"vec{n}"
193-
# integer dtype
194-
elif dtype == "int8":
195-
return f"i8vec{n}"
196-
elif dtype == "uint8":
197-
return f"u8vec{n}"
198-
elif dtype == "int16":
199-
return f"i16vec{n}"
200-
elif dtype == "uint16":
201-
return f"u16vec{n}"
202-
elif dtype == "int32" or dtype == "int":
203-
return f"ivec{n}"
204-
elif dtype == "uint32" or dtype == "uint":
205-
return f"uvec{n}"
206-
elif dtype == "int64":
207-
return f"ivec{n}"
208-
elif dtype == "uint64":
209-
return f"uvec{n}"
210-
elif dtype == "bool":
211-
return f"u8vec{n}"
212-
213-
raise AssertionError(f"Invalid dtype: {dtype}")
185+
dtype_map = {
186+
"half": f"f16vec{n}",
187+
"float": f"vec{n}",
188+
"double": f"vec{n}", # No 64bit support in GLSL
189+
"int8": f"i8vec{n}",
190+
"uint8": f"u8vec{n}",
191+
"int16": f"i16vec{n}",
192+
"uint16": f"u16vec{n}",
193+
"int32": f"ivec{n}",
194+
"int": f"ivec{n}",
195+
"uint32": f"uvec{n}",
196+
"uint": f"uvec{n}",
197+
"int64": f"ivec{n}", # No 64bit support in GLSL
198+
"uint64": f"uvec{n}", # No 64bit support in GLSL
199+
"bool": f"u8vec{n}",
200+
}
201+
202+
vector_type = dtype_map.get(dtype)
203+
if vector_type is None:
204+
raise AssertionError(f"Invalid dtype: {dtype}")
205+
206+
return vector_type
214207

215208

216209
def texel_type(dtype: str) -> str:

0 commit comments

Comments
 (0)