 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from functools import lru_cache
-from typing import Any, Callable, Dict, List, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Set, Tuple, Union, Optional
 
 import torch
+from tqdm import tqdm
 
 # from tensorrt_llm.bindings.internal.runtime import delay_kernel
 # from tensorrt_llm.logger import logger
@@ -42,16 +43,20 @@ class DynamicTensorSpec:
     """
     A specification for a dynamic tensor dimension.
     Args:
-        input_idx: The index of the input tensor.
-        dim_idx: The index of the dimension to tune.
+        input_idx: A list of the indices of the input tensors.
+        dim_idx: A list of the indices of the dimensions to tune.
+            The length of input_idx and dim_idx must be the same.
+            For every tensor mapped to the input_idx, their dimension mapped to the dim_idx must be the same.
         gen_tuning_buckets: A tuple of values to try or a function generating values.
         map_to_tuning_buckets: A function to map dimensions to valid values during inference.
+        tensor_initializers: A list of functions to initialize the tensors.
     """
 
-    input_idx: int
-    dim_idx: int
+    input_idx: Tuple[int]
+    dim_idx: Tuple[int]
     gen_tuning_buckets: Union[Tuple[int], Callable]
     map_to_tuning_buckets: Callable
+    # tensor_initializers: Tuple[Callable] = field(default_factory=lambda: [lambda shapes, dtype, device: torch.randn(shapes, device=device, dtype=dtype)])
 
 
 @dataclass(slots=True, unsafe_hash=True)
@@ -85,8 +90,8 @@ class TuningConfig:
         >>> config = TuningConfig(
         ...     dynamic_tensor_specs=(
         ...         DynamicTensorSpec(
-        ...             input_idx=0,
-        ...             dim_idx=1,
+        ...             input_idx=[0],
+        ...             dim_idx=[1],
         ...             gen_tuning_buckets=(32, 64, 128),
         ...             map_to_tuning_buckets=lambda x: ((x + 31) // 32) * 32
         ...         ),
@@ -426,7 +431,7 @@ def choose_one(
             "All Given runners must be subclass of TunableRunner"
         )
 
-        profiles = self._optimization_profiles(tuning_config, inputs)
+        profiles = self._generate_optimization_profiles(tuning_config, inputs)
         # Record the total configs to try
         self.stats.tuned_op_total_configs[custom_op] = len(profiles)
 
@@ -532,7 +537,8 @@ def _profile_single_kernel(
         # Delay the profiled kernel launch to eliminate affects of host time overhead in profiling.
         # TODO: This is build time sensitive, O(tactic_num * impl_num * num_profile * tunable_ops)
         # Consider apply a preprofiling to estimate the kernel execution time, then decide the necessity.
-        delay_kernel(self.stream_delay_micro_secs)
+        if self.stream_delay_micro_secs > 0:
+            delay_kernel(self.stream_delay_micro_secs)
         start = torch.cuda.Event(enable_timing=True)
         end = torch.cuda.Event(enable_timing=True)
 
@@ -551,7 +557,7 @@ def _profile_single_kernel(
 
         return avg_time
 
-    def _optimization_profiles(
+    def _generate_optimization_profiles(
         self, tuning_config: TuningConfig, inputs: List[torch.Tensor]
     ) -> List[OptimizationProfile]:
         """Generate optimization profiles for autotuning.
@@ -592,9 +598,12 @@ def _optimization_profiles(
             ), (
                 "The given dynamic dimension must provide a opt value generation function or a list of opt values"
             )
+            assert len(spec.input_idx) == len(spec.dim_idx), (
+                "The number of input indices and dimension indices must be the same"
+            )
             if inspect.isfunction(spec.gen_tuning_buckets):
                 opt_shapes = spec.gen_tuning_buckets(
-                    base_profile.shapes[spec.input_idx][spec.dim_idx]._opt()
+                    base_profile.shapes[spec.input_idx[0]][spec.dim_idx[0]]._opt()
                 )
             else:
                 opt_shapes = spec.gen_tuning_buckets
@@ -617,9 +626,10 @@ def _optimization_profiles(
                 # TODO: fix me, how to set the min and max?
                 min_value = opt_value
                 max_value = opt_shapes_max[opt_value]
-                p.shapes[input_idx][dim_idx] = DynamicDim(
-                    min_value, opt_value, max_value
-                )
+                for i in range(len(input_idx)):
+                    p.shapes[input_idx[i]][dim_idx[i]] = DynamicDim(
+                        min_value, opt_value, max_value
+                    )
 
             # Adjust the profile to satisfy the constraints
             for constraint_spec in tuning_config.constraint_specs:
@@ -653,14 +663,15 @@ def _find_nearest_profile(
         base_profile = list(list(shape) for shape in shapes)
 
         for spec in tuning_config.dynamic_tensor_specs:
-            base_profile[spec.input_idx][spec.dim_idx] = spec.map_to_tuning_buckets(
-                base_profile[spec.input_idx][spec.dim_idx]
+            base_profile[spec.input_idx[0]][spec.dim_idx[0]] = (
+                spec.map_to_tuning_buckets(
+                    base_profile[spec.input_idx[0]][spec.dim_idx[0]]
+                )
             )
 
         # associated dimensions dependent on other free dynamic dimensions, so assign -1 in the profile
         for constraint_spec in tuning_config.constraint_specs:
             base_profile[constraint_spec.input_idx][constraint_spec.dim_idx] = -1
-
         return tuple(tuple(shape) for shape in base_profile)
 
     @classmethod
@@ -679,7 +690,7 @@ def _get_cache_key(
         )
 
     def _create_tensor_like(
-        self, origin_tensor: torch.Tensor, dims: List[Dim]
+        self, origin_tensor: torch.Tensor, dims: List[Dim], initializer: Callable
     ) -> torch.Tensor:
         """Create a new tensor matching the properties of the original tensor.
 
@@ -704,26 +715,21 @@ def _create_tensor_like(
                 # TODO: how to make sure the created Tensor has the min/max info
                 assert isinstance(d, DynamicDim)
                 shapes.append(d.opt)
-        # TODO: FIXME, sometimes the content of the tensor can affect the performance, like MOE
-        # One solution is to manituplate the tensor content to make it more like the real data
-        # during the tuning process. This can by controlled in the preparation phase by the runner.
-        # return torch.zeros(shapes, dtype=dtype, device=device)
-        if dtype == torch.int8:
-            return torch.randint(0, 127, shapes, dtype=dtype, device=device)
-        elif dtype == torch.uint8:
-            return torch.randint(0, 255, shapes, dtype=dtype, device=device)
-        elif dtype == torch.int32:
-            return torch.randint(0, 1000000, shapes, dtype=dtype, device=device)
-        else:
-            return torch.randn(shapes, dtype=dtype, device=device)
+        return initializer(shapes, dtype, device)
 
     def _prepare_input_tensors(
         self, profile: OptimizationProfile, inputs: List[torch.Tensor]
     ) -> List[torch.Tensor]:
         tensors = []
         for i, p in enumerate(profile.shapes):
             if any(isinstance(d, DynamicDim) for d in p):
-                tensor = self._create_tensor_like(inputs[i], p)
+                tensor = self._create_tensor_like(
+                    inputs[i],
+                    p,
+                    lambda shapes, dtype, device: torch.rand(shapes, device=device).to(
+                        dtype
+                    ),
+                )
             else:
                 tensor = inputs[i]
             tensors.append(tensor)
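For reference, a minimal usage sketch of the spec after this change. The import path and the `rand_initializer` helper are assumptions for illustration; only the class names, their fields, and the `(shapes, dtype, device)` initializer signature come from the diff above.

```python
import torch

# Assumed import path; adjust to wherever this autotuner module lives.
from autotuner import DynamicTensorSpec, TuningConfig

# One dynamic dimension shared by two inputs: dim 1 of input 0 and dim 0 of
# input 1 are tuned together, so both index pairs go into a single spec
# (the new tuple-based form; both dimensions must hold the same value).
spec = DynamicTensorSpec(
    input_idx=(0, 1),
    dim_idx=(1, 0),
    gen_tuning_buckets=(32, 64, 128),
    map_to_tuning_buckets=lambda x: ((x + 31) // 32) * 32,
)
config = TuningConfig(dynamic_tensor_specs=(spec,))

# Hypothetical initializer matching the (shapes, dtype, device) signature that
# _create_tensor_like now expects to receive from its caller.
def rand_initializer(shapes, dtype, device):
    return torch.rand(shapes, device=device).to(dtype)
```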