@@ -33,22 +33,26 @@ def get_regalloc_signature_spec():
3333
3434 observation_spec = {
3535 key : tf .TensorSpec (dtype = tf .int64 , shape = (num_registers ), name = key )
36- for key in ('mask' , 'is_hint' , 'is_local' , 'is_free' )}
37- observation_spec .update (
38- {key : tensor_spec .BoundedTensorSpec (
39- dtype = tf .int64 ,
40- shape = (num_registers ),
41- name = key ,
42- minimum = 0 ,
43- maximum = 6 ) for key in ('max_stage' , 'min_stage' )})
44- observation_spec .update (
45- {key : tf .TensorSpec (dtype = tf .float32 , shape = (num_registers ), name = key )
46- for key in ('weighed_reads_by_max' , 'weighed_writes_by_max' ,
47- 'weighed_read_writes_by_max' , 'weighed_indvars_by_max' ,
48- 'hint_weights_by_max' , 'start_bb_freq_by_max' ,
49- 'end_bb_freq_by_max' , 'hottest_bb_freq_by_max' ,
50- 'liverange_size' , 'use_def_density' , 'nr_defs_and_uses' ,
51- 'nr_broken_hints' , 'nr_urgent' , 'nr_rematerializable' )})
36+ for key in ('mask' , 'is_hint' , 'is_local' , 'is_free' )
37+ }
38+ observation_spec .update ({
39+ key :
40+ tensor_spec .BoundedTensorSpec (
41+ dtype = tf .int64 ,
42+ shape = (num_registers ),
43+ name = key ,
44+ minimum = 0 ,
45+ maximum = 6 ) for key in ('max_stage' , 'min_stage' )
46+ })
47+ observation_spec .update ({
48+ key : tf .TensorSpec (dtype = tf .float32 , shape = (num_registers ), name = key )
49+ for key in ('weighed_reads_by_max' , 'weighed_writes_by_max' ,
50+ 'weighed_read_writes_by_max' , 'weighed_indvars_by_max' ,
51+ 'hint_weights_by_max' , 'start_bb_freq_by_max' ,
52+ 'end_bb_freq_by_max' , 'hottest_bb_freq_by_max' ,
53+ 'liverange_size' , 'use_def_density' , 'nr_defs_and_uses' ,
54+ 'nr_broken_hints' , 'nr_urgent' , 'nr_rematerializable' )
55+ })
5256 observation_spec ['progress' ] = tensor_spec .BoundedTensorSpec (
5357 dtype = tf .float32 , shape = (), name = 'progress' , minimum = 0 , maximum = 1 )
5458
0 commit comments