@@ -15,77 +15,29 @@
 
 import torch
 from neuronx_distributed.trace.model_builder import BaseModelInstance
-from torch_neuronx import BucketModelConfig
 from transformers import PretrainedConfig
 
 from ...config import NxDNeuronConfig
 from ...graph_builder import NxDGraphBuilder
-from ..autobucketing import (
-    get_context_encoder_bk,
-    get_generation_model_bk,
-)
 from ..generation.sampling import prepare_sampling_params
 
 
-CONTEXT_ENCODING_MODEL_TAG = "context_encoding_model"
-TOKEN_GENERATION_MODEL_TAG = "token_generation_model"
-SPECULATION_MODEL_TAG = "speculation_model"
-
-
-def get_bucket_model_config_from_tag(
-    tag, config: PretrainedConfig, neuron_config: NxDNeuronConfig, buckets: list[int]
-):
-    bucket_degree = len(buckets)
-    if bucket_degree == 1:
-        return None
-
-    pad_token = config.pad_token_id
-
-    # NOTE: KV Cache preprocessing is done within the model and not the
-    # shared buffer preprocessor due to lack of support of non-contiguous
-    # slicing of nrt tensors via the NRT API.
-    if tag == CONTEXT_ENCODING_MODEL_TAG:
-        return BucketModelConfig(
-            bucket_kernel=get_context_encoder_bk,
-            bucket_kernel_constant_args=(
-                torch.tensor(buckets),
-                pad_token,
-            ),
-            shared_state_buffer=None,
-            func_kwargs=[{"bucket_rank": i} for i in range(bucket_degree)],
-        )
-    elif tag == TOKEN_GENERATION_MODEL_TAG or tag == SPECULATION_MODEL_TAG:
-        return BucketModelConfig(
-            bucket_kernel=get_generation_model_bk,
-            bucket_kernel_constant_args=(
-                torch.tensor(buckets),
-                0,
-            ),
-            shared_state_buffer=None,
-            func_kwargs=[{"bucket_rank": i} for i in range(bucket_degree)],
-        )
-    else:
-        raise ValueError(
-            f"The supplied tag: {tag} is not supported for Bucketing. Only {CONTEXT_ENCODING_MODEL_TAG} and {TOKEN_GENERATION_MODEL_TAG} are supported"
-        )
-
-
 class NxDDecoderBuilder(NxDGraphBuilder):
     def __init__(
         self,
         config: PretrainedConfig,
         neuron_config: NxDNeuronConfig,
-        buckets: list[int],
-        bucket_n_active_tokens: bool,
+        max_tokens: int,
+        active_tokens: int,
         model_cls,
         tag="",
         priority_model_idx: int = None,
     ) -> None:
         super().__init__(tag, priority_model_idx)
         self.config = config
         self.neuron_config = neuron_config
-        self.buckets = buckets
-        self.bucket_n_active_tokens = bucket_n_active_tokens
+        self.max_tokens = max_tokens
+        self.active_tokens = active_tokens
 
         if not self.neuron_config.torch_dtype:
             self.neuron_config.torch_dtype = torch.float32
@@ -99,18 +51,16 @@ def input_generator(
         self,
     ):
         inputs = []
-        for bucket in self.buckets:
-            n_active_tokens = bucket if self.bucket_n_active_tokens else self.neuron_config.n_active_tokens
 
-            input_ids = torch.zeros((self.neuron_config.batch_size, n_active_tokens), dtype=torch.int32)
-            attention_mask = torch.zeros((self.neuron_config.batch_size, bucket), dtype=torch.int32)
-            position_ids = torch.zeros((self.neuron_config.batch_size, n_active_tokens), dtype=torch.int32)
-            seq_ids = torch.zeros((self.neuron_config.batch_size), dtype=torch.int32)
-            # Get the count of sampling params currently supported.
-            sampling_params_len = prepare_sampling_params(1).shape[1]
-            sampling_params = torch.zeros((self.neuron_config.batch_size, sampling_params_len), dtype=torch.float32)
+        input_ids = torch.zeros((self.neuron_config.batch_size, self.active_tokens), dtype=torch.int32)
+        attention_mask = torch.zeros((self.neuron_config.batch_size, self.max_tokens), dtype=torch.int32)
+        position_ids = torch.zeros((self.neuron_config.batch_size, self.active_tokens), dtype=torch.int32)
+        seq_ids = torch.zeros((self.neuron_config.batch_size), dtype=torch.int32)
+        # Get the count of sampling params currently supported.
+        sampling_params_len = prepare_sampling_params(1).shape[1]
+        sampling_params = torch.zeros((self.neuron_config.batch_size, sampling_params_len), dtype=torch.float32)
 
-            inputs.append((input_ids, attention_mask, position_ids, seq_ids, sampling_params))
+        inputs.append((input_ids, attention_mask, position_ids, seq_ids, sampling_params))
 
         return inputs
 
@@ -119,21 +69,18 @@ def get_model_instance(self):
             model_cls=self.model_cls,
             config=self.config,
             neuron_config=self.neuron_config,
-            buckets=self.buckets,
+            n_positions=self.max_tokens,
         )
 
-    def get_bucket_config(self):
-        return get_bucket_model_config_from_tag(self.tag, self.config, self.neuron_config, self.buckets)
-
 
 class DecoderModelInstance(BaseModelInstance):
-    def __init__(self, model_cls, config: PretrainedConfig, neuron_config: NxDNeuronConfig, buckets: list[int]):
+    def __init__(self, model_cls, config: PretrainedConfig, neuron_config: NxDNeuronConfig, n_positions: int):
         self.model_cls = model_cls
         self.module = None
         self.input_output_aliases = None
         self.config = config
         self.neuron_config = neuron_config
-        self.buckets = buckets
+        self.n_positions = n_positions
 
     def initialize_process_group(self, world_size):
         self.model_cls.initialize_process_group(world_size)
@@ -149,18 +96,12 @@ def load_module(self):
                 else t
             )
         self.module = float_model
+        self.module.n_positions = self.n_positions
 
     def get(self, bucket_rank, **kwargs):
-        if bucket_rank is not None:
-            self.module.n_positions = self.buckets[bucket_rank]
-
-        # Currently we have to init an input_output_aliases map for
-        # each bucket, otherwise it will fail the aliasing setup when
-        # generating HLO
+        assert bucket_rank == 0
         self.input_output_aliases = {}
         num_output_from_trace = 1 if not self.neuron_config.output_logits else 2
-        # TODO: This else block is a short-term fix for Llava/ViT models to use DecoderModelInstance.
-        # Long-term, these models should use a different implementation of BaseModelInstance.
         if self.module.kv_mgr is not None:
             past_key_values = self.module.kv_mgr.past_key_values
         else:
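
A minimal usage sketch of the simplified builder follows (not part of the diff; model_cls, config, neuron_config, and the sequence_length attribute on neuron_config are assumed caller-side names). With bucketing removed, each NxDDecoderBuilder describes exactly one traced graph whose input shapes are fixed by max_tokens (the attention-mask and KV window width) and active_tokens (the number of positions processed per forward pass), with the tag strings taken from the constants this change deletes.

# Hedged sketch, assuming caller-provided model_cls, config, neuron_config,
# and a sequence_length attribute on neuron_config.

# Context encoding: the prompt fills the whole window, so every position is active.
context_encoding = NxDDecoderBuilder(
    config=config,
    neuron_config=neuron_config,
    max_tokens=neuron_config.sequence_length,
    active_tokens=neuron_config.sequence_length,
    model_cls=model_cls,
    tag="context_encoding_model",
)

# Token generation: one new token per step is scored against the full KV window,
# so input_generator() yields input_ids of shape (batch_size, 1) and an
# attention_mask of shape (batch_size, sequence_length).
token_generation = NxDDecoderBuilder(
    config=config,
    neuron_config=neuron_config,
    max_tokens=neuron_config.sequence_length,
    active_tokens=1,
    model_cls=model_cls,
    tag="token_generation_model",
)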