Skip to content

Commit 9598a88

Browse files
committed
Remove hard coded configs
1 parent 321960a commit 9598a88

File tree

1 file changed

+2
-19
lines changed

1 file changed

+2
-19
lines changed

angel_system/global_step_prediction/global_step_predictor.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def __init__(
6363
self.activity_conf_history = np.empty((0, num_activity_classes))
6464

6565
self.recipe_types = recipe_types
66+
self.recipe_configs = recipe_config_dict
6667

6768
# Array of tracker dicts
6869
self.trackers = []
@@ -91,8 +92,6 @@ def __init__(
9192
)
9293
)
9394

94-
self.recipe_configs = recipe_config_dict
95-
9695
self.gt_activities_order_from_each_config = {
9796
_recipe: self.get_activity_order_from_config(self.recipe_configs[_recipe])
9897
for _recipe in self.recipe_configs
@@ -195,23 +194,7 @@ def initialize_new_recipe_tracker(self, recipe, config_fn=None):
195194
196195
"""
197196
tracker_dict = {}
198-
if recipe == "coffee":
199-
if config_fn == None:
200-
config_fn = "config/tasks/recipe_coffee.yaml"
201-
elif recipe == "tea":
202-
if config_fn == None:
203-
config_fn = "config/tasks/recipe_tea.yaml"
204-
elif recipe == "dessert_quesadilla":
205-
if config_fn == None:
206-
config_fn = "config/tasks/recipe_dessertquesadilla.yaml"
207-
elif recipe == "oatmeal":
208-
if config_fn == None:
209-
config_fn = "config/tasks/recipe_oatmeal.yaml"
210-
elif recipe == "pinwheel":
211-
if config_fn == None:
212-
config_fn = "config/tasks/recipe_pinwheel.yaml"
213-
else:
214-
raise ValueError(f"Invalid recipe type. Valid types: [coffee].")
197+
config_fn = self.recipe_configs[recipe]
215198

216199
# Read in task config
217200
with open(config_fn, "r") as stream:

0 commit comments

Comments
 (0)