@@ -1058,7 +1058,6 @@ def observation_spec(self):
10581058 "multi_select" : (0 , len (UnitLayer )),
10591059 "player" : (len (Player ),),
10601060 "production_queue" : (0 , len (ProductionQueue )),
1061- "radar" : (0 , len (Radar )),
10621061 "score_cumulative" : (len (ScoreCumulative ),),
10631062 "score_by_category" : (len (ScoreByCategory ), len (ScoreCategories )),
10641063 "score_by_vital" : (len (ScoreByVital ), len (ScoreVitals )),
@@ -1094,6 +1093,9 @@ def observation_spec(self):
10941093 obs_spec ["raw_units" ] = (0 , len (FeatureUnit ))
10951094 obs_spec ["raw_effects" ] = (0 , len (EffectPos ))
10961095
1096+ if aif .use_feature_units or aif .use_raw_units :
1097+ obs_spec ["radar" ] = (0 , len (Radar ))
1098+
10971099 obs_spec ["upgrades" ] = (0 ,)
10981100
10991101 if aif .use_unit_counts :
@@ -1526,12 +1528,13 @@ def cargo_units(u, pos_transform, is_raw=False):
15261528 if player_id != player .player_id :
15271529 out ["away_race_requested" ] = np .array ((race ,), dtype = np .int32 )
15281530
1529- def transform_radar (radar ):
1530- p = self ._world_to_minimap_px .fwd_pt (point .Point .build (radar .pos ))
1531- return p .x , p .y , radar .radius
1532- out ["radar" ] = named_array .NamedNumpyArray (
1533- list (map (transform_radar , obs .observation .raw_data .radar )),
1534- [None , Radar ], dtype = np .int32 )
1531+ if aif .use_feature_units or aif .use_raw_units :
1532+ def transform_radar (radar ):
1533+ p = self ._world_to_minimap_px .fwd_pt (point .Point .build (radar .pos ))
1534+ return p .x , p .y , radar .radius
1535+ out ["radar" ] = named_array .NamedNumpyArray (
1536+ list (map (transform_radar , obs .observation .raw_data .radar )),
1537+ [None , Radar ], dtype = np .int32 )
15351538
15361539 # Send the entire proto as well (in a function, so it isn't copied).
15371540 if self ._send_observation_proto :
0 commit comments