diff --git a/examples/benchmarks/TRA/src/model.py b/examples/benchmarks/TRA/src/model.py index ebafd6a521..95fa1dace3 100644 --- a/examples/benchmarks/TRA/src/model.py +++ b/examples/benchmarks/TRA/src/model.py @@ -51,7 +51,21 @@ def __init__( self.logger = get_module_logger("TRA") self.logger.info("TRA Model...") - self.model = eval(model_type)(**model_config).to(device) + # Secure model registry - whitelist of allowed model classes + # This prevents arbitrary code execution while allowing dynamic model selection + model_registry = { + "LSTM": LSTM, + "Transformer": Transformer, + } + + if model_type not in model_registry: + raise ValueError( + f"Unknown model_type: '{model_type}'. " + f"Supported types: {list(model_registry.keys())}" + ) + + model_class = model_registry[model_type] + self.model = model_class(**model_config).to(device) if model_init_state: self.model.load_state_dict(torch.load(model_init_state, map_location="cpu")["model"]) if freeze_model: