Commit 71aded3

refactor: explicitly remove bucketing
1 parent 08bc221 commit 71aded3

File tree: 7 files changed, +23 -250 lines

optimum/neuron/models/inference/backend/config.py

Lines changed: 0 additions & 4 deletions
@@ -68,7 +68,6 @@ def __init__(
         max_context_length: int | None = None,
         output_logits: bool | None = False,
         fused_qkv: bool | None = False,
-        enable_bucketing: bool | None = False,
         target: str | None = None,  # set to "trn2" for trn2
         on_device_sampling: bool | None = False,
         max_topk: int | None = 256,
@@ -106,9 +105,6 @@ def __init__(
         self.on_device_sampling = on_device_sampling
         self.max_topk = max_topk
 
-        # Bucketing
-        self.enable_bucketing = enable_bucketing
-
         # Speculative decoding
         self.speculation_length = speculation_length
         if self.speculation_length > 0:
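
Note: with enable_bucketing gone, a caller still passing the flag now fails fast. A minimal caller-side sketch (the import path is inferred from the file location, and the other keyword values are illustrative assumptions, not part of this diff):

from optimum.neuron.models.inference.backend.config import NxDNeuronConfig

# Fine after this commit: only surviving keywords are used.
neuron_config = NxDNeuronConfig(
    max_context_length=4096,
    output_logits=False,
    fused_qkv=False,
    on_device_sampling=True,
    max_topk=256,
)

# Raises TypeError: __init__() got an unexpected keyword argument 'enable_bucketing'
# neuron_config = NxDNeuronConfig(enable_bucketing=True)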

optimum/neuron/models/inference/backend/graph_builder.py

Lines changed: 0 additions & 9 deletions
@@ -16,7 +16,6 @@
 
 import torch
 from neuronx_distributed.trace.model_builder import BaseModelInstance
-from torch_neuronx import BucketModelConfig
 
 
 class NxDGraphBuilder(ABC):
@@ -40,11 +39,3 @@ def get_model_instance(self) -> BaseModelInstance:
         Used at compilation time only when tracing the model.
         """
         raise NotImplementedError
-
-    @abstractmethod
-    def get_bucket_config(self) -> BucketModelConfig | None:
-        """Return the bucket configuration
-
-        Used at compilation time only when tracing the model.
-        """
-        raise NotImplementedError
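
With get_bucket_config removed, the abstract contract shrinks to a single tracing hook. A sketch of the resulting interface as reconstructed from this diff (the constructor is assumed from the subclass call super().__init__(tag, priority_model_idx) elsewhere in this commit):

from abc import ABC, abstractmethod

from neuronx_distributed.trace.model_builder import BaseModelInstance


class NxDGraphBuilder(ABC):
    # Assumed constructor, inferred from subclass usage in decoder_builder.py.
    def __init__(self, tag: str = "", priority_model_idx: int | None = None):
        self.tag = tag
        self.priority_model_idx = priority_model_idx

    @abstractmethod
    def get_model_instance(self) -> BaseModelInstance:
        """Return the model instance.

        Used at compilation time only when tracing the model.
        """
        raise NotImplementedError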

optimum/neuron/models/inference/backend/modules/attention/attention_base.py

Lines changed: 0 additions & 1 deletion
@@ -266,7 +266,6 @@ def get_flash_attention_strategy(self, q_len) -> FlashAttentionStrategy:
         These constraints may change later.
 
         TODO: Throw an exception instead of disabling flash attention if explicitly enabled but not eligible.
-        This must consider bucketing to avoid throwing an exception for smaller buckets.
         """
         if self._qk_scale is not None:
             # If a custom qk_scale is provided, flash attention is not supported.

optimum/neuron/models/inference/backend/modules/autobucketing.py

Lines changed: 0 additions & 131 deletions
This file was deleted.

optimum/neuron/models/inference/backend/modules/decoder/decoder_builder.py

Lines changed: 17 additions & 76 deletions
@@ -15,77 +15,29 @@
 
 import torch
 from neuronx_distributed.trace.model_builder import BaseModelInstance
-from torch_neuronx import BucketModelConfig
 from transformers import PretrainedConfig
 
 from ...config import NxDNeuronConfig
 from ...graph_builder import NxDGraphBuilder
-from ..autobucketing import (
-    get_context_encoder_bk,
-    get_generation_model_bk,
-)
 from ..generation.sampling import prepare_sampling_params
 
 
-CONTEXT_ENCODING_MODEL_TAG = "context_encoding_model"
-TOKEN_GENERATION_MODEL_TAG = "token_generation_model"
-SPECULATION_MODEL_TAG = "speculation_model"
-
-
-def get_bucket_model_config_from_tag(
-    tag, config: PretrainedConfig, neuron_config: NxDNeuronConfig, buckets: list[int]
-):
-    bucket_degree = len(buckets)
-    if bucket_degree == 1:
-        return None
-
-    pad_token = config.pad_token_id
-
-    # NOTE: KV Cache preprocessing is done within the model and not the
-    # shared buffer preprocessor due to lack of support of non-contiguous
-    # slicing of nrt tensors via the NRT API.
-    if tag == CONTEXT_ENCODING_MODEL_TAG:
-        return BucketModelConfig(
-            bucket_kernel=get_context_encoder_bk,
-            bucket_kernel_constant_args=(
-                torch.tensor(buckets),
-                pad_token,
-            ),
-            shared_state_buffer=None,
-            func_kwargs=[{"bucket_rank": i} for i in range(bucket_degree)],
-        )
-    elif tag == TOKEN_GENERATION_MODEL_TAG or tag == SPECULATION_MODEL_TAG:
-        return BucketModelConfig(
-            bucket_kernel=get_generation_model_bk,
-            bucket_kernel_constant_args=(
-                torch.tensor(buckets),
-                0,
-            ),
-            shared_state_buffer=None,
-            func_kwargs=[{"bucket_rank": i} for i in range(bucket_degree)],
-        )
-    else:
-        raise ValueError(
-            f"The supplied tag: {tag} is not supported for Bucketing. Only {CONTEXT_ENCODING_MODEL_TAG} and {TOKEN_GENERATION_MODEL_TAG} are supported"
-        )
-
-
 class NxDDecoderBuilder(NxDGraphBuilder):
     def __init__(
         self,
         config: PretrainedConfig,
         neuron_config: NxDNeuronConfig,
-        buckets: list[int],
-        bucket_n_active_tokens: bool,
+        max_tokens: int,
+        active_tokens: int,
         model_cls,
         tag="",
         priority_model_idx: int = None,
     ) -> None:
         super().__init__(tag, priority_model_idx)
         self.config = config
         self.neuron_config = neuron_config
-        self.buckets = buckets
-        self.bucket_n_active_tokens = bucket_n_active_tokens
+        self.max_tokens = max_tokens
+        self.active_tokens = active_tokens
 
         if not self.neuron_config.torch_dtype:
             self.neuron_config.torch_dtype = torch.float32
@@ -99,18 +51,16 @@ def input_generator(
         self,
     ):
         inputs = []
-        for bucket in self.buckets:
-            n_active_tokens = bucket if self.bucket_n_active_tokens else self.neuron_config.n_active_tokens
 
-            input_ids = torch.zeros((self.neuron_config.batch_size, n_active_tokens), dtype=torch.int32)
-            attention_mask = torch.zeros((self.neuron_config.batch_size, bucket), dtype=torch.int32)
-            position_ids = torch.zeros((self.neuron_config.batch_size, n_active_tokens), dtype=torch.int32)
-            seq_ids = torch.zeros((self.neuron_config.batch_size), dtype=torch.int32)
-            # Get the count of sampling params currently supported.
-            sampling_params_len = prepare_sampling_params(1).shape[1]
-            sampling_params = torch.zeros((self.neuron_config.batch_size, sampling_params_len), dtype=torch.float32)
+        input_ids = torch.zeros((self.neuron_config.batch_size, self.active_tokens), dtype=torch.int32)
+        attention_mask = torch.zeros((self.neuron_config.batch_size, self.max_tokens), dtype=torch.int32)
+        position_ids = torch.zeros((self.neuron_config.batch_size, self.active_tokens), dtype=torch.int32)
+        seq_ids = torch.zeros((self.neuron_config.batch_size), dtype=torch.int32)
+        # Get the count of sampling params currently supported.
+        sampling_params_len = prepare_sampling_params(1).shape[1]
+        sampling_params = torch.zeros((self.neuron_config.batch_size, sampling_params_len), dtype=torch.float32)
 
-            inputs.append((input_ids, attention_mask, position_ids, seq_ids, sampling_params))
+        inputs.append((input_ids, attention_mask, position_ids, seq_ids, sampling_params))
 
         return inputs
 
@@ -119,21 +69,18 @@ def get_model_instance(self):
             model_cls=self.model_cls,
             config=self.config,
             neuron_config=self.neuron_config,
-            buckets=self.buckets,
+            n_positions=self.max_tokens,
         )
 
-    def get_bucket_config(self):
-        return get_bucket_model_config_from_tag(self.tag, self.config, self.neuron_config, self.buckets)
-
 
 class DecoderModelInstance(BaseModelInstance):
-    def __init__(self, model_cls, config: PretrainedConfig, neuron_config: NxDNeuronConfig, buckets: list[int]):
+    def __init__(self, model_cls, config: PretrainedConfig, neuron_config: NxDNeuronConfig, n_positions: int):
         self.model_cls = model_cls
         self.module = None
         self.input_output_aliases = None
         self.config = config
         self.neuron_config = neuron_config
-        self.buckets = buckets
+        self.n_positions = n_positions
 
     def initialize_process_group(self, world_size):
         self.model_cls.initialize_process_group(world_size)
@@ -149,18 +96,12 @@ def load_module(self):
             else t
         )
         self.module = float_model
+        self.module.n_positions = self.n_positions
 
     def get(self, bucket_rank, **kwargs):
-        if bucket_rank is not None:
-            self.module.n_positions = self.buckets[bucket_rank]
-
-        # Currently we have to init an input_output_aliases map for
-        # each buckets, otherwise it will fail the aliasing setup when
-        # generating HLO
+        assert bucket_rank == 0
        self.input_output_aliases = {}
        num_output_from_trace = 1 if not self.neuron_config.output_logits else 2
-        # TODO: This else block is a short-term fix for Llava/ViT models to use DecoderModelInstance.
-        # Long-term, these models should use a different implementation of BaseModelInstance.
        if self.module.kv_mgr is not None:
            past_key_values = self.module.kv_mgr.past_key_values
        else:
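
For orientation, a hypothetical sketch of how the refactored builder is driven: one fixed-shape graph per phase replaces the old bucket list. The model_cls, config objects, tag strings, and the 4096/1 token counts are assumed values for illustration, not part of this diff:

# Prefill graph: every position is active, so active_tokens == max_tokens.
context_builder = NxDDecoderBuilder(
    config=config,                # a transformers PretrainedConfig
    neuron_config=neuron_config,  # an NxDNeuronConfig
    max_tokens=4096,              # attention_mask width of the traced graph
    active_tokens=4096,
    model_cls=model_cls,
    tag="context_encoding_model",
)

# Decode graph: one new token per step, so active_tokens == 1.
token_builder = NxDDecoderBuilder(
    config=config,
    neuron_config=neuron_config,
    max_tokens=4096,
    active_tokens=1,
    model_cls=model_cls,
    tag="token_generation_model",
)

# Each builder now yields exactly one dummy-input tuple and one
# DecoderModelInstance pinned to n_positions == max_tokens; the
# bucket_rank argument to get() is vestigial and asserted to be 0.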

0 commit comments
