22import json
33import os
44from collections import defaultdict
5- from typing import Any , Dict , List , Optional , Tuple
5+ from typing import Any
66
77import numpy as np
88
@@ -13,7 +13,7 @@ def configure_sta(mode: str = 'STA_searching',
1313 layer_num : int = 40 ,
1414 time_step_num : int = 50 ,
1515 head_num : int = 40 ,
16- ** kwargs ) -> List [ List [ List [Any ]]]:
16+ ** kwargs ) -> list [ list [ list [Any ]]]:
1717 """
1818 Configure Sliding Tile Attention (STA) parameters based on the specified mode.
1919
@@ -53,22 +53,22 @@ def configure_sta(mode: str = 'STA_searching',
5353
5454 if mode == 'STA_searching' :
5555 # Get parameters: mask_candidates is required, mask_selected defaults to every candidate index
56- mask_candidates : Optional [ List [ str ]] = kwargs .get ('mask_candidates' )
56+ mask_candidates : list [ str ] | None = kwargs .get ('mask_candidates' )
5757 if mask_candidates is None :
5858 raise ValueError (
5959 "mask_candidates is required for STA_searching mode" )
60- mask_selected : List [int ] = kwargs .get ('mask_selected' ,
60+ mask_selected : list [int ] = kwargs .get ('mask_selected' ,
6161 list (range (len (mask_candidates ))))
6262
6363 # Parse selected masks
64- selected_masks : List [ List [int ]] = []
64+ selected_masks : list [ list [int ]] = []
6565 for index in mask_selected :
6666 mask = mask_candidates [index ]
6767 masks_list = [int (x ) for x in mask .split (',' )]
6868 selected_masks .append (masks_list )
6969
7070 # Create 3D mask structure sized by time_step_num and layer_num (defaults: t=50, l=40)
71- masks_3d : List [ List [ List [ List [int ]]]] = []
71+ masks_3d : list [ list [ list [ list [int ]]]] = []
7272 for i in range (time_step_num ): # t dimension = time_step_num (default 50)
7373 row = []
7474 for j in range (layer_num ): # l dimension = layer_num (default 40)
@@ -79,25 +79,23 @@ def configure_sta(mode: str = 'STA_searching',
7979
8080 elif mode == 'STA_tuning' :
8181 # Get required parameters
82- mask_search_files_path : Optional [ str ] = kwargs .get (
82+ mask_search_files_path : str | None = kwargs .get (
8383 'mask_search_files_path' )
8484 if not mask_search_files_path :
8585 raise ValueError (
8686 "mask_search_files_path is required for STA_tuning mode" )
8787
8888 # Get mask_candidates (required) and the remaining optional parameters with defaults
89- mask_candidates_tuning : Optional [List [str ]] = kwargs .get (
90- 'mask_candidates' )
89+ mask_candidates_tuning : list [str ] | None = kwargs .get ('mask_candidates' )
9190 if mask_candidates_tuning is None :
9291 raise ValueError ("mask_candidates is required for STA_tuning mode" )
93- mask_selected_tuning : List [int ] = kwargs .get (
92+ mask_selected_tuning : list [int ] = kwargs .get (
9493 'mask_selected' , list (range (len (mask_candidates_tuning ))))
95- skip_time_steps_tuning : Optional [int ] = kwargs .get ('skip_time_steps' )
96- save_dir_tuning : Optional [str ] = kwargs .get ('save_dir' ,
97- "mask_candidates" )
94+ skip_time_steps_tuning : int | None = kwargs .get ('skip_time_steps' )
95+ save_dir_tuning : str | None = kwargs .get ('save_dir' , "mask_candidates" )
9896
9997 # Parse selected masks
100- selected_masks_tuning : List [ List [int ]] = []
98+ selected_masks_tuning : list [ list [int ]] = []
10199 for index in mask_selected_tuning :
102100 mask = mask_candidates_tuning [index ]
103101 masks_list = [int (x ) for x in mask .split (',' )]
@@ -108,7 +106,7 @@ def configure_sta(mode: str = 'STA_searching',
108106 averaged_results = average_head_losses (results , selected_masks_tuning )
109107
110108 # Add full attention mask for specific cases
111- full_attention_mask_tuning : Optional [ List [ int ]] = kwargs .get (
109+ full_attention_mask_tuning : list [ int ] | None = kwargs .get (
112110 'full_attention_mask' )
113111 if full_attention_mask_tuning is not None :
114112 selected_masks_tuning .append (full_attention_mask_tuning )
@@ -149,28 +147,28 @@ def configure_sta(mode: str = 'STA_searching',
149147 return mask_strategy_3d
150148 elif mode == 'STA_tuning_cfg' :
151149 # Get required parameters for both positive and negative paths
152- mask_search_files_path_pos : Optional [ str ] = kwargs .get (
150+ mask_search_files_path_pos : str | None = kwargs .get (
153151 'mask_search_files_path_pos' )
154- mask_search_files_path_neg : Optional [ str ] = kwargs .get (
152+ mask_search_files_path_neg : str | None = kwargs .get (
155153 'mask_search_files_path_neg' )
156- save_dir_cfg : Optional [ str ] = kwargs .get ('save_dir' )
154+ save_dir_cfg : str | None = kwargs .get ('save_dir' )
157155
158156 if not mask_search_files_path_pos or not mask_search_files_path_neg or not save_dir_cfg :
159157 raise ValueError (
160158 "mask_search_files_path_pos, mask_search_files_path_neg, and save_dir are required for STA_tuning_cfg mode"
161159 )
162160
163161 # Get mask_candidates (required) and the remaining optional parameters with defaults
164- mask_candidates_cfg : Optional [ List [ str ]] = kwargs .get ('mask_candidates' )
162+ mask_candidates_cfg : list [ str ] | None = kwargs .get ('mask_candidates' )
165163 if mask_candidates_cfg is None :
166164 raise ValueError (
167165 "mask_candidates is required for STA_tuning_cfg mode" )
168- mask_selected_cfg : List [int ] = kwargs .get (
166+ mask_selected_cfg : list [int ] = kwargs .get (
169167 'mask_selected' , list (range (len (mask_candidates_cfg ))))
170- skip_time_steps_cfg : Optional [ int ] = kwargs .get ('skip_time_steps' )
168+ skip_time_steps_cfg : int | None = kwargs .get ('skip_time_steps' )
171169
172170 # Parse selected masks
173- selected_masks_cfg : List [ List [int ]] = []
171+ selected_masks_cfg : list [ list [int ]] = []
174172 for index in mask_selected_cfg :
175173 mask = mask_candidates_cfg [index ]
176174 masks_list = [int (x ) for x in mask .split (',' )]
@@ -187,7 +185,7 @@ def configure_sta(mode: str = 'STA_searching',
187185 selected_masks_cfg )
188186
189187 # Add full attention mask for specific cases
190- full_attention_mask_cfg : Optional [ List [ int ]] = kwargs .get (
188+ full_attention_mask_cfg : list [ int ] | None = kwargs .get (
191189 'full_attention_mask' )
192190 if full_attention_mask_cfg is not None :
193191 selected_masks_cfg .append (full_attention_mask_cfg )
@@ -227,7 +225,7 @@ def configure_sta(mode: str = 'STA_searching',
227225
228226 else : # STA_inference
229227 # Get parameters with defaults
230- load_path : Optional [ str ] = kwargs .get (
228+ load_path : str | None = kwargs .get (
231229 'load_path' , "mask_candidates/mask_strategy.json" )
232230 if load_path is None :
233231 raise ValueError ("load_path is required for STA_inference mode" )
@@ -248,9 +246,9 @@ def configure_sta(mode: str = 'STA_searching',
248246# Helper functions
249247
250248
251- def read_specific_json_files (folder_path : str ) -> List [ Dict [str , Any ]]:
249+ def read_specific_json_files (folder_path : str ) -> list [ dict [str , Any ]]:
252250 """Read and parse JSON files containing mask search results."""
253- json_contents : List [ Dict [str , Any ]] = []
251+ json_contents : list [ dict [str , Any ]] = []
254252
255253 # List entries directly inside folder_path only (no recursive walk)
256254 files = os .listdir (folder_path )
@@ -268,11 +266,11 @@ def read_specific_json_files(folder_path: str) -> List[Dict[str, Any]]:
268266
269267
270268def average_head_losses (
271- results : List [ Dict [str , Any ]],
272- selected_masks : List [ List [int ]]) -> Dict [str , Dict [str , np .ndarray ]]:
269+ results : list [ dict [str , Any ]],
270+ selected_masks : list [ list [int ]]) -> dict [str , dict [str , np .ndarray ]]:
273271 """Average losses across all prompts for each mask strategy."""
274272 # Initialize a dictionary to store the averaged results
275- averaged_losses : Dict [str , Dict [str , np .ndarray ]] = {}
273+ averaged_losses : dict [str , dict [str , np .ndarray ]] = {}
276274 loss_type = 'L2_loss'
277275 # Only the hard-coded 'L2_loss' loss type gets an entry here
278276 averaged_losses [loss_type ] = {}
@@ -294,14 +292,14 @@ def average_head_losses(
294292
295293
296294def select_best_mask_strategy (
297- averaged_results : Dict [str , Dict [str , np .ndarray ]],
298- selected_masks : List [ List [int ]],
295+ averaged_results : dict [str , dict [str , np .ndarray ]],
296+ selected_masks : list [ list [int ]],
299297 skip_time_steps : int = 12 ,
300298 timesteps : int = 50 ,
301299 head_num : int = 40
302- ) -> Tuple [ Dict [str , List [int ]], float , Dict [str , int ]]:
300+ ) -> tuple [ dict [str , list [int ]], float , dict [str , int ]]:
303301 """Select the best mask strategy for each head based on loss minimization."""
304- best_mask_strategy : Dict [str , List [int ]] = {}
302+ best_mask_strategy : dict [str , list [int ]] = {}
305303 loss_type = 'L2_loss'
306304 # Get the shape of time steps and layers
307305 layers = len (averaged_results [loss_type ][str (selected_masks [0 ])][0 ])
@@ -310,7 +308,7 @@ def select_best_mask_strategy(
310308 total_tokens = 0 # total number of masked tokens
311309 total_length = 0 # total sequence length
312310
313- strategy_counts : Dict [str , int ] = {
311+ strategy_counts : dict [str , int ] = {
314312 str (strategy ): 0
315313 for strategy in selected_masks
316314 }
@@ -352,22 +350,22 @@ def select_best_mask_strategy(
352350
353351
354352def save_mask_search_results (
355- mask_search_final_result : List [ Dict [str , List [float ]]],
353+ mask_search_final_result : list [ dict [str , list [float ]]],
356354 prompt : str ,
357- mask_strategies : List [str ],
358- output_dir : str = 'output/mask_search_result/' ) -> Optional [ str ] :
355+ mask_strategies : list [str ],
356+ output_dir : str = 'output/mask_search_result/' ) -> str | None :
359357 if not mask_search_final_result :
360358 print ("No mask search results to save" )
361359 return None
362360
363361 # Create result dictionary with defaultdict for nested lists
364- mask_search_dict : Dict [str , Dict [str , List [ List [float ]]]] = {
362+ mask_search_dict : dict [str , dict [str , list [ list [float ]]]] = {
365363 "L2_loss" : defaultdict (list ),
366364 "L1_loss" : defaultdict (list )
367365 }
368366
369367 mask_selected = list (range (len (mask_strategies )))
370- selected_masks : List [ List [int ]] = []
368+ selected_masks : list [ list [int ]] = []
371369 for index in mask_selected :
372370 mask = mask_strategies [index ]
373371 masks_list = [int (x ) for x in mask .split (',' )]
@@ -377,7 +375,7 @@ def save_mask_search_results(
377375 for i , mask_strategy in enumerate (selected_masks ):
378376 mask_strategy_str = str (mask_strategy )
379377 # Process L2 loss
380- step_results : List [ List [float ]] = []
378+ step_results : list [ list [float ]] = []
381379 for step_data in mask_search_final_result :
382380 if isinstance (step_data , dict ) and "L2_loss" in step_data :
383381 layer_losses = [float (loss ) for loss in step_data ["L2_loss" ]]
0 commit comments