import torch
import torch.nn.utils.parametrize as parametrize
+from torch.utils._python_dispatch import return_and_correct_aliasing

__all__ = [
    "benchmark_model",
@@ -409,6 +410,9 @@ def _(func, types, args, kwargs):
    if not hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE"):
        cls._ATEN_OP_OR_TORCH_FN_TABLE = {}

+    if cls not in cls._ATEN_OP_OR_TORCH_FN_TABLE:
+        cls._ATEN_OP_OR_TORCH_FN_TABLE[cls] = {}
+
    if not isinstance(aten_ops_or_torch_fns, (list, tuple)):
        aten_ops_or_torch_fns = [aten_ops_or_torch_fns]

@@ -419,12 +423,83 @@ def decorator(func):
            def wrapper(f, types, args, kwargs):
                return func(f, types, args, kwargs)

-            cls._ATEN_OP_OR_TORCH_FN_TABLE[op] = wrapper
+            cls._ATEN_OP_OR_TORCH_FN_TABLE[cls][op] = wrapper
        return func

    return decorator


+def _implements_common_tensor_ops(cls):
+    implements = cls.implements
+    aten = torch.ops.aten
+
+    @implements(
+        [aten.detach.default, aten.clone.default, aten.alias.default, aten.contiguous]
+    )
+    def _(func, types, args, kwargs):
+        return return_and_correct_aliasing(
+            func,
+            args,
+            kwargs,
+            args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs)),
+        )
+
+    def _same_metadata(self: TorchAOBaseTensor, src: TorchAOBaseTensor) -> bool:
+        _tensor_shape_match = all(
+            getattr(self, t_name).shape == getattr(src, t_name).shape
+            for t_name in self.tensor_data_names
+        )
+        _attr_match = all(
+            getattr(self, a_name) == getattr(src, a_name)
+            for a_name in self.tensor_attribute_names
+        )
+        return (
+            type(self) == type(src)
+            and self.shape == src.shape
+            and _tensor_shape_match
+            and _attr_match
+        )
+
+    @implements(aten.copy_.default)
+    def _(func, types, args, kwargs):
+        self = args[0]
+        src = args[1]
+        if _same_metadata(self, src):
+            self_tensors = self.__tensor_flatten__()[0]
+            for tensor_name in self_tensors:
+                getattr(self, tensor_name).copy_(getattr(src, tensor_name))
+            return
+        raise ValueError(
+            f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
+        )
+
+    @implements(aten._to_copy.default)
+    def _(func, types, args, kwargs):
+        self = args[0]
+        if hasattr(self, "tensor_data_names") and hasattr(
+            self, "tensor_attribute_names"
+        ):
+            kwargs = self._get_to_kwargs(*args[1:], **kwargs)
+            device = kwargs.pop("device")
+            tensors = [
+                getattr(self, name).to(device) for name in self.tensor_data_names
+            ]
+            # change device
+            tensor_attributes = [
+                getattr(self, attr_name) if attr_name != "device" else device
+                for attr_name in self.tensor_attribute_names
+            ]
+            t = self.__class__(
+                *tensors,
+                *tensor_attributes,
+            )
+            return return_and_correct_aliasing(func, args, kwargs, t)
+
+        raise NotImplementedError(
+            "Subclasses must implement `aten._to_copy.default` or specify `tensor_data_names` and `tensor_attribute_names` for tensor class or tensor instance before using it"
+        )
+
+
def _dispatch__torch_function__(cls, func, types, args=(), kwargs=None):
    """Use this util function for a common `__torch_function__` implementation
    that dispatches to ops/functions registered with `_implements`
@@ -436,9 +511,10 @@ class MyTensor(torch.Tensor):
    kwargs = {} if kwargs is None else kwargs
    if (
        hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE")
-        and func in cls._ATEN_OP_OR_TORCH_FN_TABLE
+        and cls in cls._ATEN_OP_OR_TORCH_FN_TABLE
+        and func in cls._ATEN_OP_OR_TORCH_FN_TABLE[cls]
    ):
-        return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
+        return cls._ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, types, args, kwargs)

    with torch._C.DisableTorchFunctionSubclass():
        return func(*args, **kwargs)
@@ -454,9 +530,10 @@ class MyTensor(torch.Tensor):
    """
    if (
        hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE")
-        and func in cls._ATEN_OP_OR_TORCH_FN_TABLE
+        and cls in cls._ATEN_OP_OR_TORCH_FN_TABLE
+        and func in cls._ATEN_OP_OR_TORCH_FN_TABLE[cls]
    ):
-        return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
+        return cls._ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, types, args, kwargs)

    arg_types = tuple(type(arg) for arg in args)
    kwarg_types = {k: type(arg) for k, arg in kwargs.items()}
@@ -576,7 +653,28 @@ class PlainAQTTensorImpl(...):
    """

+    @classmethod
+    def __init_subclass__(cls, **kwargs):
+        if not hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE"):
+            cls._ATEN_OP_OR_TORCH_FN_TABLE = {}
+
+        if cls not in cls._ATEN_OP_OR_TORCH_FN_TABLE:
+            cls._ATEN_OP_OR_TORCH_FN_TABLE[cls] = {}
+
+        # define the common ops if the tensor_data_names and tensor_attribute_names are defined
+        if hasattr(cls, "tensor_data_names") and hasattr(cls, "tensor_attribute_names"):
+            cls._implements_common_tensor_ops()
+
+        # inherit the torch function and dispatch implementations from direct parent classes
+        # e.g. for `class C(B, A)`, C.__bases__ == (B, A)
+        for parent in cls.__bases__:
+            if parent in cls._ATEN_OP_OR_TORCH_FN_TABLE:
+                cls._ATEN_OP_OR_TORCH_FN_TABLE[cls].update(
+                    cls._ATEN_OP_OR_TORCH_FN_TABLE[parent]
+                )
+
    implements = classmethod(_implements)
+    _implements_common_tensor_ops = classmethod(_implements_common_tensor_ops)
    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
    __torch_function__ = classmethod(_dispatch__torch_function__)
    register_layout = classmethod(_register_layout)
@@ -591,7 +689,7 @@ def __tensor_flatten__(self):
                getattr(self, attr) for attr in self.tensor_attribute_names
            ]
        raise NotImplementedError(
-            "Subclasses must implement __tensor_flatten__ or specify `tensor_data_names` and `tensor_attribute_names` for tensor class or tensor instance"
+            "Subclasses should implement __tensor_flatten__ or specify `tensor_data_names` and `tensor_attribute_names` for tensor class or tensor instance before using it"
        )

    @classmethod
@@ -602,13 +700,20 @@ def __tensor_unflatten__(
        return cls(*tensors, *tensor_attributes)

    def _apply_fn_to_data(self, fn):
-        tensors = [fn(getattr(self, attr)) for attr in self.tensor_data_names]
-        tensor_attributes = [
-            getattr(self, attr) for attr in self.tensor_attribute_names
-        ]
-        return self.__class__(
-            *tensors,
-            *tensor_attributes,
+        if hasattr(self, "tensor_data_names") and hasattr(
+            self, "tensor_attribute_names"
+        ):
+            tensors = [fn(getattr(self, attr)) for attr in self.tensor_data_names]
+            tensor_attributes = [
+                getattr(self, attr) for attr in self.tensor_attribute_names
+            ]
+            return self.__class__(
+                *tensors,
+                *tensor_attributes,
+            )
+
+        raise NotImplementedError(
+            "Subclasses should implement _apply_fn_to_data or specify `tensor_data_names` and `tensor_attribute_names` for tensor class or tensor instance before using it"
        )

    def __repr__(self):
@@ -624,7 +729,10 @@ def __repr__(self):
                    f", {tensor_attribute_name}={getattr(self, tensor_attribute_name)}"
                )
            return f"{self.__class__.__name__}({repr_str})"
-        raise NotImplementedError("Subclasses must implement __repr__")
+
+        raise NotImplementedError(
+            "Subclasses must implement __repr__ or specify `tensor_data_names` and `tensor_attribute_names` for tensor class or tensor instance before using it"
+        )

    def get_layout(self):
        if not hasattr(self, "_layout"):
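
Taken together, the diff means a TorchAOBaseTensor subclass only needs to declare `tensor_data_names` and `tensor_attribute_names`: `__init_subclass__` then gives it its own entry in `_ATEN_OP_OR_TORCH_FN_TABLE`, registers the common detach/clone/alias/contiguous/copy_/_to_copy implementations, and copies over any ops registered on its parent classes. A minimal sketch of the intended usage follows; the `MyQuantTensor` class and its fields are hypothetical and not taken from this diff.

# Hypothetical usage sketch (not part of the patch above).
import torch

from torchao.utils import TorchAOBaseTensor


class MyQuantTensor(TorchAOBaseTensor):
    # Declaring these two class attributes is what makes __init_subclass__ call
    # _implements_common_tensor_ops() for this class.
    tensor_data_names = ["packed_weight", "scale"]
    tensor_attribute_names = ["bit_width"]

    def __new__(cls, packed_weight, scale, bit_width):
        kwargs = {"device": packed_weight.device, "dtype": scale.dtype}
        return torch.Tensor._make_wrapper_subclass(cls, packed_weight.shape, **kwargs)

    def __init__(self, packed_weight, scale, bit_width):
        self.packed_weight = packed_weight
        self.scale = scale
        self.bit_width = bit_width


# detach/clone now fall through to the generic _apply_fn_to_data-based implementation,
# and __tensor_flatten__/__repr__ work without per-class overrides.
t = MyQuantTensor(torch.zeros(4, 4, dtype=torch.int8), torch.ones(4), 4)
t2 = t.detach().clone()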