@@ -146,7 +146,7 @@ def update(
146146 input_pos : torch .Tensor ,
147147 k_slice : torch .Tensor ,
148148 v_slice : torch .Tensor ,
149- enable_hlfb : bool = True ,
149+ use_dus : bool = True ,
150150) -> KVCacheEntry :
151151 """Out of place update of Cache buffer.
152152
@@ -155,17 +155,14 @@ def update(
155155 input_pos (torch.Tensor): The update slice positions.
156156 k_slice (torch.Tensor): The K slice to be updated in the new cache.
157157 v_slice (torch.Tensor): The V slice to be updated in the new cache.
158- enable_hlfb (bool, optional): Whether the op is annotated for export with
159- High Level Function Boundary. Defaults to True.
160158
161159 Returns:
162160 KVCacheEntry: The updated KVCache entry based on the passed inputs.
163161 """
164- # Don't enable HLFB for kv cache op for now, since it won't work with LLM
165- # inference engine. Remove this part once we ship a new LLM inference engine.
166- enable_hlfb = False
167- update_func = _update_kv_hlfb_impl if enable_hlfb else _update_kv_base_impl
168- return update_func (cache , input_pos , k_slice , v_slice )
162+ # Turn dynamic_update_slice updates off for now.
163+ use_dus = False
164+ update_kv_cache = _update_kv_impl if use_dus else _update_kv_base_impl
165+ return update_kv_cache (cache , input_pos , k_slice , v_slice )
169166
170167
171168def _update_kv_base_impl (
@@ -181,18 +178,28 @@ def _update_kv_base_impl(
181178 return updated_cache
182179
183180
184- def _update_kv_hlfb_impl (
181+ def _get_slice_indices (positions : torch .Tensor ) -> torch .Tensor :
182+ """Dynamic Update Slice updates are a variadic sequence of 0-rank tensors."""
183+
184+ zero = torch .zeros ([]).int ()
185+ positions = positions .int ()[0 ].reshape ([])
186+ return [zero , positions , zero , zero ]
187+
188+
189+ def _update_kv_impl (
185190 cache : KVCacheEntry ,
186191 input_pos : torch .Tensor ,
187192 k_slice : torch .Tensor ,
188193 v_slice : torch .Tensor ,
189194) -> KVCacheEntry :
190- """Update the cache buffer with High Level Function Boundary annotation."""
191- builder = hlfb .StableHLOCompositeBuilder (name = "odml.update_external_kv_cache" )
192- k_cache , v_cache , input_pos , k_slice , v_slice = builder .mark_inputs (
193- cache .k_cache , cache .v_cache , input_pos , k_slice , v_slice
194- )
195- k = k_cache .index_copy (1 , input_pos .to (torch .long ), k_slice )
196- v = v_cache .index_copy (1 , input_pos .to (torch .long ), v_slice )
197- k , v = builder .mark_outputs (k , v )
198- return KVCacheEntry (k , v )
195+ """Update the cache buffer for K and V caches."""
196+ # NB: Here assume that input_pos == range(input_pos[0], len(input_pos))
197+
198+ k_slice_indices = _get_slice_indices (input_pos )
199+ v_slice_indices = _get_slice_indices (input_pos )
200+
201+ k = dynamic_update_slice (cache .k_cache , k_slice , k_slice_indices )
202+ v = dynamic_update_slice (cache .v_cache , v_slice , v_slice_indices )
203+
204+ updated_cache = KVCacheEntry (k , v )
205+ return updated_cache
0 commit comments