1616import argparse
1717import json
1818import os
19+ from typing import List , Dict , Any
20+ from dataclasses import dataclass
21+ from omegaconf import OmegaConf
1922
2023EXPECTED_HYPERPARAMETERS = {
2124 "integer" : 1 ,
2629 "dict" : {
2730 "string" : "value" ,
2831 "integer" : 3 ,
32+ "float" : 3.14 ,
2933 "list" : [1 , 2 , 3 ],
3034 "dict" : {"key" : "value" },
3135 "boolean" : True ,
@@ -117,7 +121,7 @@ def main():
117121 assert isinstance (params ["dict" ], dict )
118122
119123 params = json .loads (os .environ ["SM_TRAINING_ENV" ])["hyperparameters" ]
120- print (params )
124+ print (f"SM_TRAINING_ENV -> hyperparameters: { params } " )
121125 assert params ["string" ] == EXPECTED_HYPERPARAMETERS ["string" ]
122126 assert params ["integer" ] == EXPECTED_HYPERPARAMETERS ["integer" ]
123127 assert params ["boolean" ] == EXPECTED_HYPERPARAMETERS ["boolean" ]
@@ -132,9 +136,96 @@ def main():
132136 assert isinstance (params ["float" ], float )
133137 assert isinstance (params ["list" ], list )
134138 assert isinstance (params ["dict" ], dict )
135- print (f"SM_TRAINING_ENV -> hyperparameters: { params } " )
136139
137- print ("Test passed." )
140+ # Local JSON - DictConfig OmegaConf
141+ params = OmegaConf .load ("hyperparameters.json" )
142+
143+ print (f"Local hyperparameters.json: { params } " )
144+ assert params .string == EXPECTED_HYPERPARAMETERS ["string" ]
145+ assert params .integer == EXPECTED_HYPERPARAMETERS ["integer" ]
146+ assert params .boolean == EXPECTED_HYPERPARAMETERS ["boolean" ]
147+ assert params .float == EXPECTED_HYPERPARAMETERS ["float" ]
148+ assert params .list == EXPECTED_HYPERPARAMETERS ["list" ]
149+ assert params .dict == EXPECTED_HYPERPARAMETERS ["dict" ]
150+ assert params .dict .string == EXPECTED_HYPERPARAMETERS ["dict" ]["string" ]
151+ assert params .dict .integer == EXPECTED_HYPERPARAMETERS ["dict" ]["integer" ]
152+ assert params .dict .boolean == EXPECTED_HYPERPARAMETERS ["dict" ]["boolean" ]
153+ assert params .dict .float == EXPECTED_HYPERPARAMETERS ["dict" ]["float" ]
154+ assert params .dict .list == EXPECTED_HYPERPARAMETERS ["dict" ]["list" ]
155+ assert params .dict .dict == EXPECTED_HYPERPARAMETERS ["dict" ]["dict" ]
156+
157+ @dataclass
158+ class DictConfig :
159+ string : str
160+ integer : int
161+ boolean : bool
162+ float : float
163+ list : List [int ]
164+ dict : Dict [str , Any ]
165+
166+ @dataclass
167+ class HPConfig :
168+ string : str
169+ integer : int
170+ boolean : bool
171+ float : float
172+ list : List [int ]
173+ dict : DictConfig
174+
175+ # Local JSON - Structured OmegaConf
176+ hp_config : HPConfig = OmegaConf .merge (
177+ OmegaConf .structured (HPConfig ), OmegaConf .load ("hyperparameters.json" )
178+ )
179+ print (f"Local hyperparameters.json - Structured: { hp_config } " )
180+ assert hp_config .string == EXPECTED_HYPERPARAMETERS ["string" ]
181+ assert hp_config .integer == EXPECTED_HYPERPARAMETERS ["integer" ]
182+ assert hp_config .boolean == EXPECTED_HYPERPARAMETERS ["boolean" ]
183+ assert hp_config .float == EXPECTED_HYPERPARAMETERS ["float" ]
184+ assert hp_config .list == EXPECTED_HYPERPARAMETERS ["list" ]
185+ assert hp_config .dict == EXPECTED_HYPERPARAMETERS ["dict" ]
186+ assert hp_config .dict .string == EXPECTED_HYPERPARAMETERS ["dict" ]["string" ]
187+ assert hp_config .dict .integer == EXPECTED_HYPERPARAMETERS ["dict" ]["integer" ]
188+ assert hp_config .dict .boolean == EXPECTED_HYPERPARAMETERS ["dict" ]["boolean" ]
189+ assert hp_config .dict .float == EXPECTED_HYPERPARAMETERS ["dict" ]["float" ]
190+ assert hp_config .dict .list == EXPECTED_HYPERPARAMETERS ["dict" ]["list" ]
191+ assert hp_config .dict .dict == EXPECTED_HYPERPARAMETERS ["dict" ]["dict" ]
192+
193+ # Local YAML - Structured OmegaConf
194+ hp_config : HPConfig = OmegaConf .merge (
195+ OmegaConf .structured (HPConfig ), OmegaConf .load ("hyperparameters.yaml" )
196+ )
197+ print (f"Local hyperparameters.yaml - Structured: { hp_config } " )
198+ assert hp_config .string == EXPECTED_HYPERPARAMETERS ["string" ]
199+ assert hp_config .integer == EXPECTED_HYPERPARAMETERS ["integer" ]
200+ assert hp_config .boolean == EXPECTED_HYPERPARAMETERS ["boolean" ]
201+ assert hp_config .float == EXPECTED_HYPERPARAMETERS ["float" ]
202+ assert hp_config .list == EXPECTED_HYPERPARAMETERS ["list" ]
203+ assert hp_config .dict == EXPECTED_HYPERPARAMETERS ["dict" ]
204+ assert hp_config .dict .string == EXPECTED_HYPERPARAMETERS ["dict" ]["string" ]
205+ assert hp_config .dict .integer == EXPECTED_HYPERPARAMETERS ["dict" ]["integer" ]
206+ assert hp_config .dict .boolean == EXPECTED_HYPERPARAMETERS ["dict" ]["boolean" ]
207+ assert hp_config .dict .float == EXPECTED_HYPERPARAMETERS ["dict" ]["float" ]
208+ assert hp_config .dict .list == EXPECTED_HYPERPARAMETERS ["dict" ]["list" ]
209+ assert hp_config .dict .dict == EXPECTED_HYPERPARAMETERS ["dict" ]["dict" ]
210+ print (f"hyperparameters.yaml -> hyperparameters: { hp_config } " )
211+
212+ # HP Dict - Structured OmegaConf
213+ hp_dict = json .loads (os .environ ["SM_HPS" ])
214+ hp_config : HPConfig = OmegaConf .merge (OmegaConf .structured (HPConfig ), OmegaConf .create (hp_dict ))
215+ print (f"SM_HPS - Structured: { hp_config } " )
216+ assert hp_config .string == EXPECTED_HYPERPARAMETERS ["string" ]
217+ assert hp_config .integer == EXPECTED_HYPERPARAMETERS ["integer" ]
218+ assert hp_config .boolean == EXPECTED_HYPERPARAMETERS ["boolean" ]
219+ assert hp_config .float == EXPECTED_HYPERPARAMETERS ["float" ]
220+ assert hp_config .list == EXPECTED_HYPERPARAMETERS ["list" ]
221+ assert hp_config .dict == EXPECTED_HYPERPARAMETERS ["dict" ]
222+ assert hp_config .dict .string == EXPECTED_HYPERPARAMETERS ["dict" ]["string" ]
223+ assert hp_config .dict .integer == EXPECTED_HYPERPARAMETERS ["dict" ]["integer" ]
224+ assert hp_config .dict .boolean == EXPECTED_HYPERPARAMETERS ["dict" ]["boolean" ]
225+ assert hp_config .dict .float == EXPECTED_HYPERPARAMETERS ["dict" ]["float" ]
226+ assert hp_config .dict .list == EXPECTED_HYPERPARAMETERS ["dict" ]["list" ]
227+ assert hp_config .dict .dict == EXPECTED_HYPERPARAMETERS ["dict" ]["dict" ]
228+ print (f"SM_HPS -> hyperparameters: { hp_config } " )
138229
139230
140231if __name__ == "__main__" :
0 commit comments