Skip to content

Commit 36b4156

Browse files
Add config utils
Signed-off-by: Thara Palanivel <[email protected]>
1 parent aa6ef46 commit 36b4156

File tree

1 file changed

+79
-0
lines changed

1 file changed

+79
-0
lines changed

fms_mo/utils/config_utils.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright The FMS Model Optimizer Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Standard
16+
import base64
17+
import json
18+
import os
19+
import pickle
20+
21+
22+
def update_config(config, **kwargs):
23+
if isinstance(config, (tuple, list)):
24+
for c in config:
25+
update_config(c, **kwargs)
26+
else:
27+
for k, v in kwargs.items():
28+
if hasattr(config, k):
29+
setattr(config, k, v)
30+
elif "." in k:
31+
# allow --some_config.some_param=True
32+
config_name, param_name = k.split(".")
33+
if type(config).__name__ == config_name:
34+
if hasattr(config, param_name):
35+
setattr(config, param_name, v)
36+
else:
37+
# In case of specialized config we can warm user
38+
print(f"Warning: {config_name} does not accept parameter: {k}")
39+
40+
41+
def get_json_config():
42+
"""Parses JSON configuration if provided via environment variables
43+
FMS_MO_CONFIG_JSON_ENV_VAR or FMS_MO_CONFIG_JSON_PATH.
44+
45+
FMS_MO_CONFIG_JSON_ENV_VAR is the base64 encoded JSON.
46+
FMS_MO_CONFIG_JSON_PATH is the path to the JSON config file.
47+
48+
Returns: dict or {}
49+
"""
50+
json_env_var = os.getenv("FMS_MO_CONFIG_JSON_ENV_VAR")
51+
json_path = os.getenv("FMS_MO_CONFIG_JSON_PATH")
52+
53+
# accepts either path to JSON file or encoded string config
54+
# env var takes precedent
55+
job_config_dict = {}
56+
if json_env_var:
57+
job_config_dict = txt_to_obj(json_env_var)
58+
elif json_path:
59+
with open(json_path, "r", encoding="utf-8") as f:
60+
job_config_dict = json.load(f)
61+
62+
return job_config_dict
63+
64+
65+
def txt_to_obj(txt):
66+
"""Given encoded byte string, converts to base64 decoded dict.
67+
68+
Args:
69+
txt: str
70+
Returns: dict[str, Any]
71+
"""
72+
base64_bytes = txt.encode("ascii")
73+
message_bytes = base64.b64decode(base64_bytes)
74+
try:
75+
# If the bytes represent JSON string
76+
return json.loads(message_bytes)
77+
except UnicodeDecodeError:
78+
# Otherwise the bytes are a pickled python dictionary
79+
return pickle.loads(message_bytes)

0 commit comments

Comments
 (0)