Skip to content

Commit 1bbf605

Browse files
authored
Concepts in confog.yaml conditions (#507)
1 parent 70fb5ad commit 1bbf605

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

clarifai/runners/models/model_builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,8 @@ def get_model_version_proto(self):
477477
for concept in labels:
478478
concept_proto = json_format.ParseDict(concept, resources_pb2.Concept())
479479
model_version_proto.output_info.data.concepts.append(concept_proto)
480-
else:
480+
elif self.config.get("checkpoints") and HuggingFaceLoader.validate_concept(
481+
self.checkpoint_path):
481482
labels = HuggingFaceLoader.fetch_labels(self.checkpoint_path)
482483
logger.info(f"Found {len(labels)} concepts from the model checkpoints.")
483484
# sort the concepts by id and then update the config file

clarifai/runners/utils/loader.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,18 @@ def validate_config(checkpoint_path: str):
162162
return os.path.exists(checkpoint_path) and os.path.exists(
163163
os.path.join(checkpoint_path, 'config.json'))
164164

165+
@staticmethod
166+
def validate_concept(checkpoint_path: str):
167+
# check if downloaded concept exists in hf model
168+
config_path = os.path.join(checkpoint_path, 'config.json')
169+
with open(config_path, 'r') as f:
170+
config = json.load(f)
171+
172+
labels = config.get('id2label', None)
173+
if labels:
174+
return True
175+
return False
176+
165177
@staticmethod
166178
def fetch_labels(checkpoint_path: str):
167179
# Fetch labels for classification, detection and segmentation models

0 commit comments

Comments
 (0)