20
20
from typing import Any , Callable , Dict , List , Optional , Tuple , Type , Generator , Union
21
21
import json
22
22
23
- # from absl import flags
23
+ from absl import flags
24
24
from absl import logging
25
25
import bisect
26
26
import dataclasses
46
46
from compiler_opt .distributed import buffered_scheduler
47
47
from compiler_opt .distributed .local import local_worker_manager
48
48
49
+ _PERSISTENT_OBJECTS_PATH = flags .DEFINE_string (
50
+ 'persistent_objects_path' , None ,
51
+ ('If specified, the temp compiled binaries throughout'
52
+ 'the trajectory generation will be saved in persistent_objects_path'
53
+ 'for linking the final binary.' ))
54
+
55
+ FLAGS = flags .FLAGS
56
+
49
57
ProfilingDictValueType = Dict [str , Union [str , float , int ]]
50
58
51
59
@@ -318,6 +326,7 @@ def __init__(
318
326
tensor_spec .BoundedTensorSpec ,
319
327
]] = None ,
320
328
reward_key : str = '' ,
329
+ explicit_temps_dir : Optional [str ] = None ,
321
330
** kwargs ,
322
331
):
323
332
self ._loaded_module_spec = loaded_module_spec
@@ -343,6 +352,7 @@ def __init__(
343
352
task_type = mlgo_task_type ,
344
353
obs_spec = obs_spec ,
345
354
action_spec = action_spec ,
355
+ explicit_temps_dir = explicit_temps_dir ,
346
356
interactive_only = True ,
347
357
)
348
358
if self ._env .action_spec :
@@ -603,8 +613,8 @@ def _process_obs(self, curr_obs, sequence_example):
603
613
class ModuleWorkerResultProcessor :
604
614
"""Utility class to process ModuleExplorer results for ModuleWorker."""
605
615
606
- def __init__ (self , base_path : Optional [str ] = None ):
607
- self ._base_path = base_path
616
+ def __init__ (self , persistent_objects_path : Optional [str ] = None ):
617
+ self ._persistent_objects_path = persistent_objects_path
608
618
609
619
def _partition_for_loss (self , seq_example : tf .train .SequenceExample ,
610
620
partitions : List [float ], label_name : str ):
@@ -654,12 +664,13 @@ def process_succeeded(
654
664
logging .info ('best policy idx: %s, best exploration idxs %s' ,
655
665
best_policy_idx , best_exploration_idxs )
656
666
657
- if self ._base_path :
667
+ if self ._persistent_objects_path :
658
668
# as long as we have one process handles one module this can stay here
659
669
temp_working_dir_idx = working_dir_list [best_policy_idx ][1 ]
660
670
temp_working_dir_list = working_dir_list [best_policy_idx ][0 ]
661
671
temp_working_dir = temp_working_dir_list [temp_working_dir_idx ]
662
- self ._save_binary (self ._base_path , spec_name , temp_working_dir )
672
+ self ._save_binary (self ._persistent_objects_path , spec_name ,
673
+ temp_working_dir )
663
674
664
675
self ._partition_for_loss (seq_example , partitions , label_name )
665
676
@@ -689,11 +700,12 @@ def _profiling_dict(
689
700
}
690
701
return per_module_dict
691
702
692
- def _save_binary (self , base_path : str , save_path : str , binary_path : str ):
703
+ def _save_binary (self , persistent_objects_path : str , save_path : str ,
704
+ binary_path : str ):
693
705
path_head_tail = os .path .split (save_path )
694
706
path_head = path_head_tail [0 ]
695
707
path_tail = path_head_tail [1 ]
696
- save_dir = os .path .join (base_path , path_head )
708
+ save_dir = os .path .join (persistent_objects_path , path_head )
697
709
if not os .path .exists (save_dir ):
698
710
os .makedirs (save_dir , exist_ok = True )
699
711
shutil .copy (
@@ -725,7 +737,8 @@ class ModuleWorker(worker.Worker):
725
737
explore_on_features: dict of feature names and functions which specify
726
738
when to explore on the respective feature
727
739
obs_action_specs: optional observation spec annotating TimeStep
728
- base_path: root path to save best compiled binaries for linking
740
+ persistent_objects_path: root path to save best compiled binaries
741
+ for linking
729
742
partitions: a tuple of limits defining the buckets, see partition_for_loss
730
743
env_args: additional arguments to pass to the ModuleExplorer, used in
731
744
creating the environment. This has to include the reward_key
@@ -748,7 +761,7 @@ def __init__(
748
761
time_step .TimeStep ,
749
762
tensor_spec .BoundedTensorSpec ,
750
763
]] = None ,
751
- base_path : Optional [str ] = None ,
764
+ persistent_objects_path : Optional [str ] = None ,
752
765
partitions : List [float ] = [
753
766
0. ,
754
767
],
@@ -775,8 +788,8 @@ def __init__(
775
788
[tf .Tensor ], bool ]]] = explore_on_features
776
789
self ._obs_action_specs : Optional [Tuple [
777
790
time_step .TimeStep , tensor_spec .BoundedTensorSpec ]] = obs_action_specs
778
- self ._mw_utility = ModuleWorkerResultProcessor (base_path )
779
- self ._base_path = base_path
791
+ self ._mw_utility = ModuleWorkerResultProcessor (persistent_objects_path )
792
+ self ._persistent_objects_path = persistent_objects_path
780
793
self ._partitions = partitions
781
794
self ._envargs = envargs
782
795
@@ -858,7 +871,7 @@ def select_best_exploration(
858
871
try :
859
872
shutil .rmtree (temp_dir_head )
860
873
except FileNotFoundError as e :
861
- if not self ._base_path :
874
+ if not self ._persistent_objects_path :
862
875
continue
863
876
else :
864
877
raise FileNotFoundError (
@@ -918,6 +931,13 @@ def gen_trajectories(
918
931
worker_manager_class: A pool of workers hosted on the local machines, each
919
932
in its own process.
920
933
"""
934
+ explicit_temps_dir = FLAGS .explicit_temps_dir
935
+ persistent_objects_path = _PERSISTENT_OBJECTS_PATH .value
936
+ if not explicit_temps_dir and persistent_objects_path :
937
+ logging .warning ('Setting explicit_temps_dir to persistent_objects_path=%s' ,
938
+ persistent_objects_path )
939
+ explicit_temps_dir = os .path .join (persistent_objects_path , 'temp_dirs' )
940
+
921
941
cps = corpus .Corpus (data_path = data_path , delete_flags = delete_flags )
922
942
logging .info ('Done loading module specs from corpus.' )
923
943
@@ -944,6 +964,8 @@ def gen_trajectories(
944
964
mlgo_task_type = mlgo_task_type ,
945
965
callable_policies = callable_policies ,
946
966
explore_on_features = explore_on_features ,
967
+ persistent_objects_path = persistent_objects_path ,
968
+ explicit_temps_dir = explicit_temps_dir ,
947
969
gin_config_str = gin .config_str (),
948
970
) as lwm :
949
971
0 commit comments