@@ -1143,59 +1143,88 @@ impl BlockContext<'_> {
11431143 ) ,
11441144 } ,
11451145 fun @ ( Mf :: Dot4I8Packed | Mf :: Dot4U8Packed ) => {
1146- // TODO: consider using packed integer dot product if PackedVectorFormat4x8Bit is available
1147- let ( extract_op, arg0_id, arg1_id) = match fun {
1148- Mf :: Dot4U8Packed => ( spirv:: Op :: BitFieldUExtract , arg0_id, arg1_id) ,
1149- Mf :: Dot4I8Packed => {
1150- // Convert both packed arguments to signed integers so that we can apply the
1151- // `BitFieldSExtract` operation on them in `write_dot_product` below.
1152- let new_arg0_id = self . gen_id ( ) ;
1153- block. body . push ( Instruction :: unary (
1154- spirv:: Op :: Bitcast ,
1155- result_type_id,
1156- new_arg0_id,
1157- arg0_id,
1158- ) ) ;
1146+ if self
1147+ . writer
1148+ . require_all ( & [
1149+ spirv:: Capability :: DotProduct ,
1150+ spirv:: Capability :: DotProductInput4x8BitPacked ,
1151+ ] )
1152+ . is_ok ( )
1153+ {
1154+ // Write optimized code using `PackedVectorFormat4x8Bit`.
1155+ self . writer . use_extension ( "SPV_KHR_integer_dot_product" ) ;
1156+
1157+ let op = match fun {
1158+ Mf :: Dot4I8Packed => spirv:: Op :: SDot ,
1159+ Mf :: Dot4U8Packed => spirv:: Op :: UDot ,
1160+ _ => unreachable ! ( ) ,
1161+ } ;
11591162
1160- let new_arg1_id = self . gen_id ( ) ;
1161- block. body . push ( Instruction :: unary (
1162- spirv:: Op :: Bitcast ,
1163- result_type_id,
1164- new_arg1_id,
1165- arg1_id,
1166- ) ) ;
1163+ block. body . push ( Instruction :: ternary (
1164+ op,
1165+ result_type_id,
1166+ id,
1167+ arg0_id,
1168+ arg1_id,
1169+ spirv:: PackedVectorFormat :: PackedVectorFormat4x8Bit as Word ,
1170+ ) ) ;
1171+ } else {
1172+ // Fall back to a polyfill since `PackedVectorFormat4x8Bit` is not available.
1173+ let ( extract_op, arg0_id, arg1_id) = match fun {
1174+ Mf :: Dot4U8Packed => ( spirv:: Op :: BitFieldUExtract , arg0_id, arg1_id) ,
1175+ Mf :: Dot4I8Packed => {
1176+ // Convert both packed arguments to signed integers so that we can apply the
1177+ // `BitFieldSExtract` operation on them in `write_dot_product` below.
1178+ let new_arg0_id = self . gen_id ( ) ;
1179+ block. body . push ( Instruction :: unary (
1180+ spirv:: Op :: Bitcast ,
1181+ result_type_id,
1182+ new_arg0_id,
1183+ arg0_id,
1184+ ) ) ;
11671185
1168- ( spirv:: Op :: BitFieldSExtract , new_arg0_id, new_arg1_id)
1169- }
1170- _ => unreachable ! ( ) ,
1171- } ;
1186+ let new_arg1_id = self . gen_id ( ) ;
1187+ block. body . push ( Instruction :: unary (
1188+ spirv:: Op :: Bitcast ,
1189+ result_type_id,
1190+ new_arg1_id,
1191+ arg1_id,
1192+ ) ) ;
11721193
1173- let eight = self . writer . get_constant_scalar ( crate :: Literal :: U32 ( 8 ) ) ;
1194+ ( spirv:: Op :: BitFieldSExtract , new_arg0_id, new_arg1_id)
1195+ }
1196+ _ => unreachable ! ( ) ,
1197+ } ;
11741198
1175- const VEC_LENGTH : u8 = 4 ;
1176- let bit_shifts: [ _ ; VEC_LENGTH as usize ] = core:: array:: from_fn ( |index| {
1177- self . writer
1178- . get_constant_scalar ( crate :: Literal :: U32 ( index as u32 * 8 ) )
1179- } ) ;
1199+ let eight = self . writer . get_constant_scalar ( crate :: Literal :: U32 ( 8 ) ) ;
1200+
1201+ const VEC_LENGTH : u8 = 4 ;
1202+ let bit_shifts: [ _ ; VEC_LENGTH as usize ] =
1203+ core:: array:: from_fn ( |index| {
1204+ self . writer
1205+ . get_constant_scalar ( crate :: Literal :: U32 ( index as u32 * 8 ) )
1206+ } ) ;
1207+
1208+ self . write_dot_product (
1209+ id,
1210+ result_type_id,
1211+ arg0_id,
1212+ arg1_id,
1213+ VEC_LENGTH as Word ,
1214+ block,
1215+ |result_id, composite_id, index| {
1216+ Instruction :: ternary (
1217+ extract_op,
1218+ result_type_id,
1219+ result_id,
1220+ composite_id,
1221+ bit_shifts[ index as usize ] ,
1222+ eight,
1223+ )
1224+ } ,
1225+ ) ;
1226+ }
11801227
1181- self . write_dot_product (
1182- id,
1183- result_type_id,
1184- arg0_id,
1185- arg1_id,
1186- VEC_LENGTH as Word ,
1187- block,
1188- |result_id, composite_id, index| {
1189- Instruction :: ternary (
1190- extract_op,
1191- result_type_id,
1192- result_id,
1193- composite_id,
1194- bit_shifts[ index as usize ] ,
1195- eight,
1196- )
1197- } ,
1198- ) ;
11991228 self . cached [ expr_handle] = id;
12001229 return Ok ( ( ) ) ;
12011230 }
0 commit comments