@@ -2004,68 +2004,19 @@ def warmup_scenario(self,
                                                  self.device)
         # Dummy run.
         htorch.core.mark_step()
-        logits = self._execute_model_generic(input_ids_device,
-                                             position_ids_device,
-                                             attn_metadata,
-                                             logits_indices_device, kv_caches,
-                                             True)
+        _ = self._execute_model_generic(input_ids_device, position_ids_device,
+                                        attn_metadata, logits_indices_device,
+                                        kv_caches, True)
         # TODO: do sampling on logits, warmup sampler and prefill joiner
         htorch.core.mark_step()
         temperature = torch.ones(batch_size, dtype=torch.float32, device='cpu')
         top_p = torch.ones(batch_size, dtype=torch.float32, device='cpu')
         top_k = torch.ones(batch_size, dtype=torch.float32, device='cpu')
-        temperature_device = _async_h2d_tensor_copy(temperature, self.device)
-        top_p_device = _async_h2d_tensor_copy(top_p, self.device)
-        top_k_device = _async_h2d_tensor_copy(top_k, self.device)
-        generators = {
-            i: None
-            for i in range(batch_size)
-        }  # NOTE(kzawora): idk what to set here
-        max_num_logprobs = 0  # NOTE(kzawora): idk what to set here
-        # NOTE(kzawora): do this in a smarter way
+        _ = _async_h2d_tensor_copy(temperature, self.device)
+        _ = _async_h2d_tensor_copy(top_p, self.device)
+        _ = _async_h2d_tensor_copy(top_k, self.device)
         self.profiler.end()
         return None
-        htorch.core.mark_step()
-        sampling_metadata = SamplingMetadata(
-            temperature=temperature_device,
-            all_greedy=False,  # hacky
-            all_random=True,  # hacky
-            top_p=top_p_device,
-            top_k=top_k_device,
-            no_top_p=True,
-            no_top_k=True,
-            generators=generators,
-            max_num_logprobs=max_num_logprobs,
-        )
-        tokens_all_random = self.sampler(logits, sampling_metadata)
-        htorch.core.mark_step()
-        sampling_metadata = SamplingMetadata(
-            temperature=temperature_device,
-            all_greedy=True,  # hacky
-            all_random=False,  # hacky
-            top_p=top_p_device,
-            top_k=top_k_device,
-            no_top_p=True,
-            no_top_k=True,
-            generators=generators,
-            max_num_logprobs=max_num_logprobs,
-        )
-        tokens_all_greedy = self.sampler(logits, sampling_metadata)
-        htorch.core.mark_step()
-        sampling_metadata = SamplingMetadata(
-            temperature=temperature_device,
-            all_greedy=False,  # hacky
-            all_random=False,  # hacky
-            top_p=top_p_device,
-            top_k=top_k_device,
-            no_top_p=True,
-            no_top_k=True,
-            generators=generators,
-            max_num_logprobs=max_num_logprobs,
-        )
-        tokens_mixed = self.sampler(logits, sampling_metadata)
-        htorch.core.mark_step()
-        return tokens_all_random, tokens_all_greedy, tokens_mixed
 
     def log_warmup(self, phase, i, max_i, batch_size, seq_len, num_blocks):
         free_mem = format_bytes(
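Note on this hunk: `warmup_scenario` now runs the forward pass and the host-to-device copies purely for their side effects (graph compilation and copy-path warmup) and discards the results via `_ =`, while the unreachable sampler-warmup code that sat below `return None` is deleted outright. A minimal sketch of that discard-only warmup pattern, in plain CPU PyTorch for illustration; the `mark_step` shim stands in for `htorch.core.mark_step()`, and all names here are illustrative assumptions, not this module's actual API:

```python
import torch


def mark_step() -> None:
    # On Gaudi this would be htorch.core.mark_step(), which flushes the
    # lazily accumulated graph to the device; on CPU there is nothing to do.
    pass


def warmup_shape(model: torch.nn.Module, batch_size: int, seq_len: int) -> None:
    """Dummy-run one (batch_size, seq_len) bucket so the backend caches it."""
    input_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)
    mark_step()
    # Assign to `_` so the dummy logits are not kept alive: the run exists
    # only to get a compiled graph for this shape, not for its outputs.
    _ = model(input_ids)
    mark_step()


# e.g. warm every bucketed shape up front:
# for bs, sl in [(1, 128), (1, 256), (4, 128)]:
#     warmup_shape(model, bs, sl)
```

Since warmup presumably iterates over every bucketed (batch_size, seq_len) shape, not binding the dummy logits avoids holding each bucket's output tensor alive across iterations.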
@@ -2351,30 +2302,6 @@ def __del__(self):
     @torch.inference_mode()
     def profile_run(self) -> None:
         return
-        """Profile to measure peak memory during forward pass."""
-
-        # Use an empty tensor instead of `None` to force Dynamo to pass it
-        # by reference rather than specializing on the value `None`.
-        # The `dtype` argument does not matter; we use `float32` as a
-        # placeholder since it has wide hardware support.
-        # It is important to create tensors inside the loop rather than
-        # multiplying the list, to keep Dynamo from treating them as
-        # aliases of the same tensor.
-        num_layers = self.model_config.get_num_layers(self.parallel_config)
-        kv_caches = [None] * num_layers
-
-        # Run empty prefill forwards - prefill max batch and prefill max seq
-        self.warmup_scenario(batch_size=1,
-                             seq_or_block=self.max_model_len,
-                             is_prompt=True,
-                             kv_caches=kv_caches)
-        max_seq_len = math.ceil(
-            (self.max_num_tokens // self.max_prefill_batch_size) /
-            self.block_size) * self.block_size
-        self.warmup_scenario(batch_size=self.max_prefill_batch_size,
-                             seq_or_block=max_seq_len,
-                             is_prompt=True,
-                             kv_caches=kv_caches)
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
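Note on this hunk: `profile_run` becomes an explicit no-op. The bare `return` on its first line was already short-circuiting the method, so the body below it (dummy per-layer KV caches plus two max-shape prefill warmups) was dead code and is now removed. The one non-obvious piece of the deleted code was rounding the per-request prompt length up to a KV-cache block boundary; a worked example of that arithmetic, with made-up values purely for illustration:

```python
import math

# Illustrative values only; the real numbers come from the runner's config.
max_num_tokens = 8192        # token budget across the whole prefill batch
max_prefill_batch_size = 3   # requests in the largest prefill batch
block_size = 128             # KV-cache block size

per_request = max_num_tokens // max_prefill_batch_size          # 2730
max_seq_len = math.ceil(per_request / block_size) * block_size  # 22 * 128
assert max_seq_len == 2816   # rounded up to the next block boundary
```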