 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from functools import lru_cache
-from typing import Any, Callable, Dict, List, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Set, Tuple, Union, Optional

 import torch

@@ -37,21 +37,49 @@ def get_config_path(is_module: bool):
 )


-@dataclass(slots=True, unsafe_hash=True)
+@dataclass(slots=True)
 class DynamicTensorSpec:
     """
     A specification for a dynamic tensor dimension.
     Args:
-        input_idx: The index of the input tensor.
-        dim_idx: The index of the dimension to tune.
+        input_idx: The indices of the input tensors.
+        dim_idx: The indices of the dimensions to tune.
+            input_idx and dim_idx must have the same length.
+            For every tensor referenced by input_idx, the dimension referenced by the
+            corresponding entry of dim_idx must have the same size.
         gen_tuning_buckets: A tuple of values to try or a function generating values.
         map_to_tuning_buckets: A function to map dimensions to valid values during inference.
+        tensor_initializers: A list of functions used to initialize the tensors.
     """

-    input_idx: int
-    dim_idx: int
-    gen_tuning_buckets: Union[Tuple[int], Callable]
+    input_idx: Tuple[int, ...]
+    dim_idx: Tuple[int, ...]
+    gen_tuning_buckets: Union[Tuple[int, ...], Callable]
     map_to_tuning_buckets: Callable
+    tensor_initializers: List[Callable] = field(default_factory=lambda: None)
+
+    def __post_init__(self):
+        # Set default tensor_initializers if not provided
+        if self.tensor_initializers is None:
+            self.tensor_initializers = [
+                lambda shapes, dtype, device: torch.randn(
+                    shapes, device=device, dtype=dtype
+                )
+                for _ in range(len(self.input_idx))
+            ]
+
+    def __hash__(self) -> int:
+        # FIXME: currently not hashing tensor_initializers
+        return hash(
+            (
+                self.input_idx,
+                self.dim_idx,
+                # For gen_tuning_buckets, only hash if it's a tuple, otherwise hash its id
+                self.gen_tuning_buckets
+                if isinstance(self.gen_tuning_buckets, tuple)
+                else id(self.gen_tuning_buckets),
+                id(self.map_to_tuning_buckets),
+            )
+        )


 @dataclass(slots=True, unsafe_hash=True)
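Since `input_idx` and `dim_idx` are now tuples, one spec can tie several tensors together so their paired dimensions are always tuned to the same size. A minimal usage sketch, assuming `DynamicTensorSpec` is imported from this module; the particular indices and initializers below are illustrative:

```python
import torch

# Hypothetical spec: tune dim 1 of input 0 and dim 0 of input 2 in lockstep.
# tensor_initializers is optional; __post_init__ falls back to torch.randn per index.
spec = DynamicTensorSpec(
    input_idx=(0, 2),                  # tensors that share the tuned dimension
    dim_idx=(1, 0),                    # matching dimension index for each tensor
    gen_tuning_buckets=(32, 64, 128),  # candidate sizes tried during tuning
    map_to_tuning_buckets=lambda x: ((x + 31) // 32) * 32,  # round up to a bucket
    tensor_initializers=[
        lambda shapes, dtype, device: torch.randn(shapes, device=device, dtype=dtype),
        lambda shapes, dtype, device: torch.zeros(shapes, device=device, dtype=dtype),
    ],
)

# map_to_tuning_buckets rounds an observed size up to a multiple of 32,
# e.g. 100 -> ((100 + 31) // 32) * 32 = 128, so it lands on a profiled bucket.
```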
@@ -85,8 +113,8 @@ class TuningConfig:
         >>> config = TuningConfig(
         ...     dynamic_tensor_specs=(
         ...         DynamicTensorSpec(
-        ...             input_idx=0,
-        ...             dim_idx=1,
+        ...             input_idx=(0,),
+        ...             dim_idx=(1,),
         ...             gen_tuning_buckets=(32, 64, 128),
         ...             map_to_tuning_buckets=lambda x: ((x + 31) // 32) * 32
         ...         ),
@@ -141,6 +169,7 @@ class OptimizationProfile:
     """Ranges of all tensors, all dimensions"""

     shapes: List[List[Dim]]
+    tensor_initializers: List[Optional[Callable]]

     def get_hash_key(self):
         return self.get_opt_shapes()
@@ -426,7 +455,7 @@ def choose_one(
             "All given runners must be subclasses of TunableRunner"
         )

-        profiles = self._optimization_profiles(tuning_config, inputs)
+        profiles = self._generate_optimization_profiles(tuning_config, inputs)
         # Record the total configs to try
         self.stats.tuned_op_total_configs[custom_op] = len(profiles)
@@ -532,7 +561,8 @@ def _profile_single_kernel(
         # Delay the profiled kernel launch to eliminate effects of host-time overhead in profiling.
         # TODO: This is build-time sensitive, O(tactic_num * impl_num * num_profile * tunable_ops)
         # Consider applying preprofiling to estimate the kernel execution time, then decide whether it is necessary.
-        delay_kernel(self.stream_delay_micro_secs)
+        if self.stream_delay_micro_secs > 0:
+            delay_kernel(self.stream_delay_micro_secs)
         start = torch.cuda.Event(enable_timing=True)
         end = torch.cuda.Event(enable_timing=True)
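For reference, the measurement this guard wraps follows the standard CUDA-event pattern. A self-contained sketch of that pattern in plain PyTorch; only `delay_kernel` and `stream_delay_micro_secs` belong to this module, the helper below is illustrative:

```python
import torch

def time_kernel_ms(fn, warmup: int = 3, iters: int = 10) -> float:
    """Average GPU time of fn() in milliseconds, measured with CUDA events."""
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    # A short delay kernel can be enqueued here (as delay_kernel does above) so that
    # host-side launch overhead is hidden behind queued GPU work instead of being timed.
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters
```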
@@ -551,7 +581,7 @@ def _profile_single_kernel(

         return avg_time

-    def _optimization_profiles(
+    def _generate_optimization_profiles(
         self, tuning_config: TuningConfig, inputs: List[torch.Tensor]
     ) -> List[OptimizationProfile]:
         """Generate optimization profiles for autotuning.
@@ -579,7 +609,8 @@ def _optimization_profiles(
                 else [StaticDim(0)]
                 )
                 for t in inputs
-            ]
+            ],
+            [None] * len(inputs),
         )

         generated_profiles: List[OptimizationProfile] = []
@@ -592,9 +623,18 @@ def _optimization_profiles(
             ), (
                 "The given dynamic dimension must provide an opt value generation function or a list of opt values"
             )
+            assert len(spec.input_idx) == len(spec.dim_idx), (
+                f"The number of input indices and dimension indices must be the same, got {len(spec.input_idx)} and {len(spec.dim_idx)}"
+            )
+            assert len(spec.tensor_initializers) == len(spec.input_idx), (
+                f"The number of tensor initializers and input indices must be the same, got {len(spec.tensor_initializers)} and {len(spec.input_idx)}"
+            )
+            for i, idx in enumerate(spec.input_idx):
+                base_profile.tensor_initializers[idx] = spec.tensor_initializers[i]
+
             if inspect.isfunction(spec.gen_tuning_buckets):
                 opt_shapes = spec.gen_tuning_buckets(
-                    base_profile.shapes[spec.input_idx][spec.dim_idx]._opt()
+                    base_profile.shapes[spec.input_idx[0]][spec.dim_idx[0]]._opt()
                 )
             else:
                 opt_shapes = spec.gen_tuning_buckets
@@ -617,9 +657,10 @@ def _optimization_profiles(
                 # TODO: fix me, how to set the min and max?
                 min_value = opt_value
                 max_value = opt_shapes_max[opt_value]
-                p.shapes[input_idx][dim_idx] = DynamicDim(
-                    min_value, opt_value, max_value
-                )
+                for i in range(len(input_idx)):
+                    p.shapes[input_idx[i]][dim_idx[i]] = DynamicDim(
+                        min_value, opt_value, max_value
+                    )

             # Adjust the profile to satisfy the constraints
             for constraint_spec in tuning_config.constraint_specs:
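To make the new fan-out concrete: every candidate bucket value is written to each paired `(input_idx, dim_idx)` slot, so the tied dimensions always carry the same `DynamicDim`. A toy illustration with stand-in data, not the tuner's real structures:

```python
from dataclasses import dataclass

@dataclass
class ToyDynamicDim:       # stand-in for the module's DynamicDim(min, opt, max)
    min: int
    opt: int
    max: int

input_idx, dim_idx = (0, 2), (1, 0)        # paired tensor / dimension indices
shapes = [[16, 100], [16, 16], [100, 64]]  # toy per-tensor dimension sizes
opt_value, max_value = 128, 128

# The loop above does the equivalent of this: one DynamicDim per paired slot.
for i in range(len(input_idx)):
    shapes[input_idx[i]][dim_idx[i]] = ToyDynamicDim(opt_value, opt_value, max_value)

assert shapes[0][1] == shapes[2][0] == ToyDynamicDim(128, 128, 128)
```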
@@ -653,14 +694,15 @@ def _find_nearest_profile(
         base_profile = list(list(shape) for shape in shapes)

         for spec in tuning_config.dynamic_tensor_specs:
-            base_profile[spec.input_idx][spec.dim_idx] = spec.map_to_tuning_buckets(
-                base_profile[spec.input_idx][spec.dim_idx]
+            base_profile[spec.input_idx[0]][spec.dim_idx[0]] = (
+                spec.map_to_tuning_buckets(
+                    base_profile[spec.input_idx[0]][spec.dim_idx[0]]
+                )
             )

         # Associated dimensions depend on other free dynamic dimensions, so assign -1 in the profile
         for constraint_spec in tuning_config.constraint_specs:
             base_profile[constraint_spec.input_idx][constraint_spec.dim_idx] = -1
-
         return tuple(tuple(shape) for shape in base_profile)

     @classmethod
@@ -679,7 +721,7 @@ def _get_cache_key(
         )

     def _create_tensor_like(
-        self, origin_tensor: torch.Tensor, dims: List[Dim]
+        self, origin_tensor: torch.Tensor, dims: List[Dim], initializer: Callable
     ) -> torch.Tensor:
         """Create a new tensor matching the properties of the original tensor.
@@ -704,18 +746,22 @@ def _create_tensor_like(
             # TODO: how to make sure the created Tensor has the min/max info
             assert isinstance(d, DynamicDim)
             shapes.append(d.opt)
-        # TODO: FIXME, sometimes the content of the tensor can affect the performance, like MOE
-        # One solution is to manipulate the tensor content to make it more like the real data
-        # during the tuning process. This can be controlled in the preparation phase by the runner.
-        return torch.zeros(shapes, dtype=dtype, device=device)
+        return initializer(shapes, dtype, device)

     def _prepare_input_tensors(
         self, profile: OptimizationProfile, inputs: List[torch.Tensor]
     ) -> List[torch.Tensor]:
+        default_initializer = lambda shapes, dtype, device: torch.rand(
+            shapes, device=device
+        ).to(dtype)
         tensors = []
         for i, p in enumerate(profile.shapes):
             if any(isinstance(d, DynamicDim) for d in p):
-                tensor = self._create_tensor_like(inputs[i], p)
+                tensor = self._create_tensor_like(
+                    inputs[i],
+                    p,
+                    profile.tensor_initializers[i] or default_initializer,
+                )
             else:
                 tensor = inputs[i]
             tensors.append(tensor)
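The TODO removed above (tensor content can change kernel performance, e.g. for MoE) is what `tensor_initializers` addresses: callers can now supply tuning data that resembles the real workload. A hedged sketch of such an initializer; the factory name and expert-id scheme are illustrative, not part of this module:

```python
import torch

def make_routing_like_initializer(num_experts: int):
    """Build an initializer that fills tuning tensors with expert-id-like integers
    instead of zeros, so profiling sees content closer to real MoE inputs."""
    def init(shapes, dtype, device):
        # Uniform expert ids cast to the requested dtype; swap in a skewed
        # distribution if the real workload routes tokens unevenly.
        return torch.randint(0, num_experts, list(shapes), device=device).to(dtype)
    return init

# Illustrative wiring: pass it per tuned input via
# DynamicTensorSpec(..., tensor_initializers=[make_routing_like_initializer(8)]).
```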