Skip to content

Commit a73b6bf

Browse files
author
morelos
committed
Update on "[ET-VK] double, short, and uint16 dtype runtime support"
Creating support for double, short, and uint16 for quantization ops. Registering the short keyword since theres already support. Also changing the cpu implementation to support half Differential Revision: [D75959063](https://our.internmc.facebook.com/intern/diff/D75959063/) [ghstack-poisoned]
2 parents d40315c + e5e2fc6 commit a73b6bf

File tree

1 file changed

+22
-29
lines changed

1 file changed

+22
-29
lines changed

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,6 @@ def buffer_scalar_type(dtype: str) -> str:
173173
# we don't want to append _t for int32 or uint32 as int is already 32bit
174174
elif dtype == "int32" or dtype == "uint32":
175175
return "int" if dtype == "int32" else "uint"
176-
elif dtype == "int64" or dtype == "uint64":
177-
return "int" if dtype == "int64" else "uint"
178176
elif dtype[-1].isdigit():
179177
return dtype + "_t"
180178
return dtype
@@ -184,33 +182,28 @@ def buffer_gvec_type(dtype: str, n: int) -> str:
184182
if n == 1:
185183
return buffer_scalar_type(dtype)
186184

187-
if dtype == "half":
188-
return f"f16vec{n}"
189-
elif dtype == "float":
190-
return f"vec{n}"
191-
elif dtype == "double":
192-
return f"vec{n}"
193-
# integer dtype
194-
elif dtype == "int8":
195-
return f"i8vec{n}"
196-
elif dtype == "uint8":
197-
return f"u8vec{n}"
198-
elif dtype == "int16":
199-
return f"i16vec{n}"
200-
elif dtype == "uint16":
201-
return f"u16vec{n}"
202-
elif dtype == "int32" or dtype == "int":
203-
return f"ivec{n}"
204-
elif dtype == "uint32" or dtype == "uint":
205-
return f"uvec{n}"
206-
elif dtype == "int64":
207-
return f"ivec{n}"
208-
elif dtype == "uint64":
209-
return f"uvec{n}"
210-
elif dtype == "bool":
211-
return f"u8vec{n}"
212-
213-
raise AssertionError(f"Invalid dtype: {dtype}")
185+
dtype_map = {
186+
"half": f"f16vec{n}",
187+
"float": f"vec{n}",
188+
"double": f"vec{n}", # No 64bit support in GLSL
189+
"int8": f"i8vec{n}",
190+
"uint8": f"u8vec{n}",
191+
"int16": f"i16vec{n}",
192+
"uint16": f"u16vec{n}",
193+
"int32": f"ivec{n}",
194+
"int": f"ivec{n}",
195+
"uint32": f"uvec{n}",
196+
"uint": f"uvec{n}",
197+
"int64": f"ivec{n}", # No 64bit support in GLSL
198+
"uint64": f"uvec{n}", # No 64bit support in GLSL
199+
"bool": f"u8vec{n}",
200+
}
201+
202+
vector_type = dtype_map.get(dtype)
203+
if vector_type is None:
204+
raise AssertionError(f"Invalid dtype: {dtype}")
205+
206+
return vector_type
214207

215208

216209
def texel_type(dtype: str) -> str:

0 commit comments

Comments
 (0)