diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py index dcb95da3be..701d43b8e0 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py @@ -9,6 +9,7 @@ import dataclasses import json +import logging from typing import Any, Optional import torch @@ -78,7 +79,20 @@ def from_dict(cls, data: dict[str, Any]): @classmethod # pyre-ignore [3] def from_json(cls, data: str): - return cls.from_dict(json.loads(data)) + raw = json.loads(data) + allowed = {f.name for f in dataclasses.fields(cls)} + filtered = {k: v for k, v in raw.items() if k in allowed} + missing = allowed - set(filtered.keys()) + extra = set(raw.keys()) - allowed + if missing: + logging.warning( + f"TBEDataConfig.from_json: Missing expected fields not loaded: {sorted(missing)}" + ) + if extra: + logging.info( + f"TBEDataConfig.from_json: Ignored unknown fields from input: {sorted(extra)}" + ) + return cls.from_dict(filtered) def dict(self) -> dict[str, Any]: tmp = dataclasses.asdict(self)