Bump transformers and torch #117
Changes from 8 commits
```diff
@@ -54,12 +54,12 @@ def __init__(

         # Create a list of CustomKVCache instances, one per layer
         self.kv_cache = torch.nn.ModuleList()
-        for _ in range(config.num_hidden_layers):
+        for layer in self.layers:
             layer_cache = CustomKVCache(
-                max_batch_size=self.max_batch_size,
-                max_context_length=self.max_cache_len,
-                n_heads=self.num_key_value_heads,
-                head_dim=self.head_dim,
+                max_batch_size=layer.max_batch_size,
+                max_context_length=layer.max_cache_len,
+                n_heads=layer.num_heads,
+                head_dim=layer.head_dim,
                 dtype=dtype,
             )
             self.kv_cache.append(layer_cache)
```
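The hunk above moves the cache construction from `range(config.num_hidden_layers)` to iterating the transformers cache layers directly, reading each layer's geometry off the layer object. A minimal sketch of that pattern, with the cache class passed in as a parameter so no ExecuTorch import path has to be assumed; the layer attribute names are the ones used in the diff.

```python
import torch


def build_per_layer_caches(layers, custom_kv_cache_cls, dtype=torch.float32):
    """Build one custom KV cache per transformers cache layer.

    Each layer object is assumed to expose max_batch_size, max_cache_len,
    num_heads and head_dim, as in the diff above.
    """
    kv_cache = torch.nn.ModuleList()
    for layer in layers:
        kv_cache.append(
            custom_kv_cache_cls(
                max_batch_size=layer.max_batch_size,
                max_context_length=layer.max_cache_len,
                n_heads=layer.num_heads,
                head_dim=layer.head_dim,
                dtype=dtype,
            )
        )
    return kv_cache
```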
```diff
@@ -202,32 +202,29 @@ def __init__(
             layer_device_map=layer_device_map,
         )

         # make sure layer_device_map is none
         assert layer_device_map is None
         assert device is None or device == "cpu", "Device must be None or 'cpu'"

         self.cache_position = None
-        # Create a list of cache instances, one per layer
-        # Use CustomKVCache for global layers and CustomRingKVCache for sliding window layers
+        # Create a list of cache instances, one per layer.
+        # Use CustomKVCache for global layers and CustomRingKVCache for sliding window layers.
         self.kv_cache = torch.nn.ModuleList()
-        for layer_idx in range(config.num_hidden_layers):
-            # newer version of transfomer has is_sliding defined
-            # for HybridCache
-            if self.is_sliding[layer_idx]:
+        for layer in self.layers:
+            if layer.is_sliding:
                 # This is a sliding window layer
                 layer_cache = CustomRingKVCache(
-                    max_batch_size=self.max_batch_size,
-                    max_context_length=self.sliding_window_len,
-                    n_heads=self.num_key_value_heads,
-                    head_dim=self.head_dim,
+                    max_batch_size=layer.max_batch_size,
+                    max_context_length=layer.max_cache_len,
```
A review thread is attached to the `max_context_length=layer.max_cache_len,` line above:

> wait, what is happening here? Is this the same as sliding_window_len?

> Yeah, they removed it: https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py#L357

The hunk continues:
```diff
+                    n_heads=layer.num_heads,
+                    head_dim=layer.head_dim,
                     dtype=dtype,
                 )
             else:
                 layer_cache = CustomKVCache(
-                    max_batch_size=self.max_batch_size,
-                    max_context_length=self.max_cache_len,
-                    n_heads=self.num_key_value_heads,
-                    head_dim=self.head_dim,
+                    max_batch_size=layer.max_batch_size,
+                    max_context_length=layer.max_cache_len,
+                    n_heads=layer.num_heads,
+                    head_dim=layer.head_dim,
                     dtype=dtype,
                 )
             self.kv_cache.append(layer_cache)
```
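On the sliding_window_len question in the thread above: in the newer transformers cache layout each sliding-window layer's `max_cache_len` is already capped at the window size, so it can be passed straight to `CustomRingKVCache`. A hedged sketch of the per-layer choice; the import path is the one used by ExecuTorch's llama example and should be treated as an assumption.

```python
import torch

# Assumed import path for ExecuTorch's custom caches; adjust if it has moved.
from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
    CustomKVCache,
    CustomRingKVCache,
)


def pick_layer_cache(layer, dtype=torch.float32):
    """Choose a ring (sliding-window) or static cache for one transformers cache layer."""
    cache_cls = CustomRingKVCache if layer.is_sliding else CustomKVCache
    return cache_cls(
        max_batch_size=layer.max_batch_size,
        # For a sliding layer this is the window length (the role the removed
        # sliding_window_len used to play); for a global layer it is the full
        # max_cache_len.
        max_context_length=layer.max_cache_len,
        n_heads=layer.num_heads,
        head_dim=layer.head_dim,
        dtype=dtype,
    )
```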
```diff
@@ -284,7 +281,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:

         # For CustomRingKVCache, we need to handle the sequence length differently
         layer_cache = self.kv_cache[layer_idx]
-        if self.is_sliding[layer_idx]:
+        if self.layers[layer_idx].is_sliding:
             # CustomRingKVCache cache_position_manager which
             # maintains cache position for each slot in the kv cache
             # we return the max position + 1 to indicate max position
```
```diff
@@ -308,7 +305,7 @@ def get_layer_cache(self, layer_idx: int):

 def replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
     """
-    Replace all KV caches in the module with ETCustomStaticCache.
+    Replace all KV caches in the module with ETCustomStaticCache or ETCustomHybridCache.
     This modifies the model in place.

     Args:
```
```diff
@@ -342,18 +339,18 @@ def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dt
     if getattr(module, "replace_cache", None) is not None:
         static_cache = ETCustomStaticCache(
             config=config,
-            max_batch_size=generation_config.cache_config.batch_size,
-            max_cache_len=generation_config.cache_config.max_cache_len,
-            device=generation_config.cache_config.device,
+            max_batch_size=generation_config.cache_config.get("batch_size"),
+            max_cache_len=generation_config.cache_config.get("max_cache_len"),
+            device=generation_config.cache_config.get("device"),
             dtype=cache_dtype,
         )
         module.replace_cache(static_cache)
     else:
         module.static_cache = ETCustomStaticCache(
             config=config,
-            max_batch_size=generation_config.cache_config.batch_size,
-            max_cache_len=generation_config.cache_config.max_cache_len,
-            device=generation_config.cache_config.device,
+            max_batch_size=generation_config.cache_config.get("batch_size"),
+            max_cache_len=generation_config.cache_config.get("max_cache_len"),
+            device=generation_config.cache_config.get("device"),
             dtype=cache_dtype,
         )
     # Dont know why we need to this even though
```
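The other recurring change in this hunk is that `generation_config.cache_config` is now read with dict-style `.get(...)` instead of attribute access, so absent keys return `None` rather than raising. A small sketch of that access pattern; the fallback defaults below are illustrative only.

```python
def read_cache_config(cache_config, default_batch_size=1, default_max_cache_len=128):
    """Read cache settings from a dict-like cache_config (key names as in the diff)."""
    batch_size = cache_config.get("batch_size") or default_batch_size
    max_cache_len = cache_config.get("max_cache_len") or default_max_cache_len
    device = cache_config.get("device")  # None is fine; the earlier assert expects None or "cpu"
    return batch_size, max_cache_len, device
```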
```diff
@@ -370,25 +367,25 @@ def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dt
     if getattr(module, "replace_cache", None) is not None:
         hybrid_cache = ETCustomHybridCache(
             config=config,
-            max_batch_size=generation_config.cache_config.batch_size,
-            max_cache_len=generation_config.cache_config.max_cache_len,
-            device=generation_config.cache_config.device,
+            max_batch_size=generation_config.cache_config.get("batch_size"),
+            max_cache_len=generation_config.cache_config.get("max_cache_len"),
+            device=generation_config.cache_config.get("device"),
             dtype=cache_dtype,
         )
         module.replace_cache(hybrid_cache)
     else:
         module.cache = ETCustomHybridCache(
             config=config,
-            max_batch_size=generation_config.cache_config.batch_size,
-            max_cache_len=generation_config.cache_config.max_cache_len,
-            device=generation_config.cache_config.device,
+            max_batch_size=generation_config.cache_config.get("batch_size"),
+            max_cache_len=generation_config.cache_config.get("max_cache_len"),
+            device=generation_config.cache_config.get("device"),
             dtype=cache_dtype,
         )
         # Register cache attributes for each layer
         for i in range(len(module.cache.kv_cache)):
             setattr(module, f"key_cache_{i}", module.cache.kv_cache[i].k_cache)
             setattr(module, f"value_cache_{i}", module.cache.kv_cache[i].v_cache)
-            if module.cache.is_sliding[i]:
+            if module.cache.layers[i].is_sliding:
                 # Register cache_positions as buffer for sliding window layers
                 # This prevents it from being traced as a constant
                 module.register_buffer(
```
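The registration loop at the end of this hunk exposes each layer's k/v tensors as module attributes and, for sliding layers, registers cache positions as a buffer so they are traced as state rather than constants. A hedged sketch of that pattern in isolation; the diff truncates the `register_buffer(...)` call, so the buffer name and tensor below are placeholders, not the PR's actual arguments.

```python
import torch


def register_layer_cache_attrs(module, cache):
    """Expose per-layer cache tensors on `module` (mirrors the loop in the diff)."""
    for i, layer_cache in enumerate(cache.kv_cache):
        setattr(module, f"key_cache_{i}", layer_cache.k_cache)
        setattr(module, f"value_cache_{i}", layer_cache.v_cache)
        if cache.layers[i].is_sliding:
            # Placeholder buffer: one slot position per entry in the sliding window.
            # Registering it as a (non-persistent) buffer keeps export from
            # folding it into the graph as a constant.
            module.register_buffer(
                f"cache_positions_{i}",
                torch.zeros(cache.layers[i].max_cache_len, dtype=torch.long),
                persistent=False,
            )
```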
A second review thread, on replacing the `range(config.num_hidden_layers)` iteration:

> what happened here? Like, config doesn't exist anymore?

> It still exists; it just feels more idiomatic to iterate over the actual layers.
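For the question in this thread: `config.num_hidden_layers` is still available; the PR simply reads per-layer information from the cache's `layers` list, which the newer transformers cache already carries. A minimal sketch of the equivalent lookups.

```python
def layer_is_sliding(cache, layer_idx: int) -> bool:
    """The flag did not go away; it moved from a list on the cache onto each layer."""
    # Old transformers: cache.is_sliding[layer_idx]
    # New transformers: each cache layer carries its own flag.
    return cache.layers[layer_idx].is_sliding
```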