2121# Registry to track all ops with reference implementations
2222_REGISTERED_REF_IMPLEMENTATIONS : set [str ] = set ()
2323
24+ _OUTPUTS_TYPE = torch .Tensor | tuple [torch .Tensor , ...]
25+
2426
# Custom impl wrapper that tracks registrations
def impl_tracked(
    lib: Library, op_name: str
) -> Callable[[Callable[..., _OUTPUTS_TYPE]], Callable[..., _OUTPUTS_TYPE]]:
    """Drop-in replacement for ``impl`` that also records the op name.

    Adds ``op_name`` to ``_REGISTERED_REF_IMPLEMENTATIONS`` before delegating
    to ``impl``, so the set of ops with reference implementations can be
    inspected later.
    """
    _REGISTERED_REF_IMPLEMENTATIONS.add(op_name)
    decorator = impl(lib, op_name)
    return decorator
@@ -312,7 +314,7 @@ def quantized_add_per_tensor(
312314 dequant_Y = Y_scale * (Y - Y_zero_point )
313315
314316 # q_min/q_max are unused args
315- return quantize_per_tensor (
317+ out = quantize_per_tensor (
316318 dequant_X + dequant_Y ,
317319 out_scale ,
318320 out_zero_point ,
@@ -321,6 +323,9 @@ def quantized_add_per_tensor(
321323 dtype ,
322324 )
323325
326+ assert isinstance (out , torch .Tensor )
327+ return out
328+
324329
325330@impl_tracked (m , "quantized_add_asym8sxasym8s_asym8s.per_tensor" )
326331def quantized_add_asym8sxasym8s_asym8s_per_tensor (
@@ -338,9 +343,11 @@ def quantized_add_asym8sxasym8s_asym8s_per_tensor(
338343 if Y .dtype != torch .int8 :
339344 raise ValueError ("Y dtype must be torch.int8" )
340345
341- return quantized_add_per_tensor (
346+ out = quantized_add_per_tensor (
342347 X , X_scale , X_zero_point , Y , Y_scale , Y_zero_point , out_scale , out_zero_point
343348 )
349+ assert isinstance (out , torch .Tensor )
350+ return out
344351
345352
346353@impl_tracked (m , "quantized_add_asym8uxasym8u_asym8u.per_tensor" )
@@ -359,9 +366,11 @@ def quantized_add_asym8uxasym8u_asym8u_per_tensor(
359366 if Y .dtype != torch .uint8 :
360367 raise ValueError ("Y dtype must be torch.int8" )
361368
362- return quantized_add_per_tensor (
369+ out = quantized_add_per_tensor (
363370 X , X_scale , X_zero_point , Y , Y_scale , Y_zero_point , out_scale , out_zero_point
364371 )
372+ assert isinstance (out , torch .Tensor )
373+ return out
365374
366375
367376def quantized_linear_common (
@@ -407,14 +416,16 @@ def quantized_linear_common(
407416 (weight - weight_zero_point ).float (),
408417 bias .float (),
409418 )
410- return quantize_per_tensor (
419+ out = quantize_per_tensor (
411420 out ,
412421 out_scale ,
413422 out_zero_point ,
414423 torch .iinfo (dtype ).min ,
415424 torch .iinfo (dtype ).max ,
416425 dtype ,
417- ).reshape (* leading_dims , N )
426+ )
427+ assert isinstance (out , torch .Tensor )
428+ return out .reshape (* leading_dims , N )
418429
419430
420431def quantized_linear_variant (
@@ -576,14 +587,16 @@ def quantized_matmul(
576587 (X - X_zero_point ).float (),
577588 (Y - Y_zero_point ).float (),
578589 )
579- return quantize_per_tensor (
590+ out = quantize_per_tensor (
580591 out ,
581592 out_scale ,
582593 out_zero_point ,
583594 torch .iinfo (X .dtype ).min ,
584595 torch .iinfo (X .dtype ).max ,
585596 X .dtype ,
586597 )
598+ assert isinstance (out , torch .Tensor )
599+ return out
587600
588601
589602@impl_tracked (m , "quantized_matmul_asym8sxasym8s_asym8s" )
@@ -603,7 +616,7 @@ def quantized_matmul_asym8sxasym8s_asym8s(
603616 if Y .dtype != torch .int8 :
604617 raise ValueError ("Y dtype must be torch.int8" )
605618
606- return quantized_matmul (
619+ out = quantized_matmul (
607620 X ,
608621 X_zero_point ,
609622 Y ,
@@ -614,6 +627,8 @@ def quantized_matmul_asym8sxasym8s_asym8s(
614627 out_zero_point ,
615628 transposed ,
616629 )
630+ assert isinstance (out , torch .Tensor )
631+ return out
617632
618633
619634@impl_tracked (m , "quantized_matmul_asym8uxasym8u_asym8u" )
@@ -633,7 +648,7 @@ def quantized_matmul_asym8uxasym8u_asym8u(
633648 if Y .dtype != torch .uint8 :
634649 raise ValueError ("Y dtype must be torch.uint8" )
635650
636- return quantized_matmul (
651+ out = quantized_matmul (
637652 X ,
638653 X_zero_point ,
639654 Y ,
@@ -644,6 +659,8 @@ def quantized_matmul_asym8uxasym8u_asym8u(
644659 out_zero_point ,
645660 transposed ,
646661 )
662+ assert isinstance (out , torch .Tensor )
663+ return out
647664
648665
649666@impl_tracked (m , "quantized_layer_norm.per_tensor" )
@@ -681,18 +698,21 @@ def quantized_layer_norm_per_tensor(
681698 float_input_tensor = dequantize_per_tensor (
682699 input_tensor , X_scale , X_zero_point , - 128 , 127 , input_tensor .dtype
683700 )
701+ assert isinstance (float_input_tensor , torch .Tensor )
684702 out = torch .nn .functional .layer_norm (
685703 float_input_tensor , normalized_shape , weight , bias , eps = eps
686704 )
687705
688- return quantize_per_tensor (
706+ out = quantize_per_tensor (
689707 out ,
690708 output_scale ,
691709 output_zero_point ,
692710 torch .iinfo (input_tensor .dtype ).min ,
693711 torch .iinfo (input_tensor .dtype ).max ,
694712 input_tensor .dtype ,
695713 )
714+ assert isinstance (out , torch .Tensor )
715+ return out
696716
697717
698718def quantized_conv_per_tensor (
@@ -754,14 +774,16 @@ def quantized_conv_per_tensor(
754774 else :
755775 raise ValueError ("Input tensor must be 3D or 4D" )
756776
757- return quantize_per_tensor (
777+ out = quantize_per_tensor (
758778 float_out ,
759779 output_scale ,
760780 output_zero_point ,
761781 torch .iinfo (input_tensor .dtype ).min ,
762782 torch .iinfo (input_tensor .dtype ).max ,
763783 input_tensor .dtype ,
764784 )
785+ assert isinstance (out , torch .Tensor )
786+ return out
765787
766788
767789@impl_tracked (m , "quantized_conv2d_nchw.per_tensor" )
@@ -971,7 +993,7 @@ def variant(
971993 # Call the appropriate base function
972994 match layout :
973995 case "nchw" :
974- return quantized_conv2d_nchw_per_tensor (
996+ out = quantized_conv2d_nchw_per_tensor (
975997 input_tensor ,
976998 weight ,
977999 bias ,
@@ -988,7 +1010,7 @@ def variant(
9881010 out_shift ,
9891011 )
9901012 case "nhwc" :
991- return quantized_conv2d_nhwc_per_tensor (
1013+ out = quantized_conv2d_nhwc_per_tensor (
9921014 input_tensor ,
9931015 weight ,
9941016 bias ,
@@ -1007,6 +1029,9 @@ def variant(
10071029 case _:
10081030 raise ValueError (f"Unknown layout { layout } " )
10091031
1032+ assert isinstance (out , torch .Tensor )
1033+ return out
1034+
10101035 return variant
10111036
10121037 return decorator
@@ -1281,14 +1306,16 @@ def quantized_relu_common(
12811306 dequantized_X = torch .where (
12821307 X > X_zero_point , X - X_zero_point , torch .zeros_like (X )
12831308 ).to (torch .float32 )
1284- return quantize_per_tensor (
1309+ out = quantize_per_tensor (
12851310 dequantized_X ,
12861311 out_scale ,
12871312 out_zero_point ,
12881313 torch .iinfo (X .dtype ).min ,
12891314 torch .iinfo (X .dtype ).max ,
12901315 X .dtype ,
12911316 )
1317+ assert isinstance (out , torch .Tensor )
1318+ return out
12921319
12931320
12941321def quantized_relu_variant (
@@ -1545,7 +1572,7 @@ def im2row_per_tensor(
15451572 in_zero_point : int ,
15461573 channel_last : bool = False ,
15471574) -> torch .Tensor :
1548- return im2row (
1575+ out = im2row (
15491576 input_tensor ,
15501577 kernel_size ,
15511578 dilation ,
@@ -1554,6 +1581,8 @@ def im2row_per_tensor(
15541581 torch .tensor (in_zero_point , dtype = torch .int32 ),
15551582 channel_last ,
15561583 )
1584+ assert isinstance (out , torch .Tensor )
1585+ return out
15571586
15581587
15591588@impl_tracked (m , "transposed_im2row" )
@@ -1761,3 +1790,15 @@ def idma_load(src: torch.Tensor, task_num: int = 0, channel: int = 0) -> torch.T
@impl_tracked(m, "idma_wait")
def idma_wait(src: torch.Tensor, task_num: int = 0, channel: int = 0) -> torch.Tensor:
    """Reference implementation of ``idma_wait``: return a copy of ``src``.

    ``task_num`` and ``channel`` are accepted for schema compatibility but
    are unused by this reference implementation.
    """
    return torch.clone(src)
1793+
1794+
@impl_tracked(m, "linalg_svd")
def linalg_svd(
    A: torch.Tensor,
    full_matrices: bool = False,
    compute_uv: bool = True,
    driver: str | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Reference implementation of ``linalg_svd``.

    Thin wrapper over ``torch.linalg.svd`` that returns ``.contiguous()``
    copies of the three factors.

    Args:
        A: Input tensor to decompose.
        full_matrices: Forwarded to ``torch.linalg.svd``; controls whether
            full-sized U and Vh are computed.
        compute_uv: Must be True; only the full (U, S, Vh) form is supported
            by this reference implementation.
        driver: Optional backend driver name forwarded to ``torch.linalg.svd``.

    Returns:
        Tuple ``(U, S, Vh)`` of contiguous tensors.

    Raises:
        ValueError: If ``compute_uv`` is False.
    """
    # Explicit raise rather than `assert`: asserts are stripped under
    # `python -O`, which would silently skip this validation.
    if not compute_uv:
        raise ValueError("compute_uv=False is not supported")
    U, S, Vh = torch.linalg.svd(A, full_matrices=full_matrices, driver=driver)
    return U.contiguous(), S.contiguous(), Vh.contiguous()
0 commit comments