@@ -41,7 +41,7 @@ class Metadata:
4141 base_models : Optional [list [dict ]] = None
4242 tags : Optional [list [str ]] = None
4343 languages : Optional [list [str ]] = None
44- datasets : Optional [list [str ]] = None
44+ datasets : Optional [list [dict ]] = None
4545
4646 @staticmethod
4747 def load (metadata_override_path : Optional [Path ] = None , model_path : Optional [Path ] = None , model_name : Optional [str ] = None , total_params : int = 0 ) -> Metadata :
@@ -91,9 +91,11 @@ def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Pat
9191 # Base Models is received here as an array of models
9292 metadata .base_models = metadata_override .get ("general.base_models" , metadata .base_models )
9393
94+ # Datasets is received here as an array of datasets
95+ metadata .datasets = metadata_override .get ("general.datasets" , metadata .datasets )
96+
9497 metadata .tags = metadata_override .get (Keys .General .TAGS , metadata .tags )
9598 metadata .languages = metadata_override .get (Keys .General .LANGUAGES , metadata .languages )
96- metadata .datasets = metadata_override .get (Keys .General .DATASETS , metadata .datasets )
9799
98100 # Direct Metadata Override (via direct cli argument)
99101 if model_name is not None :
@@ -346,12 +348,12 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
346348 use_model_card_metadata ("author" , "model_creator" )
347349 use_model_card_metadata ("basename" , "model_type" )
348350
349- if "base_model" in model_card :
351+ if "base_model" in model_card or "base_models" in model_card :
350352 # This represents the parent models that this is based on
351353 # Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
352354 # Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md
353355 metadata_base_models = []
354- base_model_value = model_card .get ("base_model" , None )
356+ base_model_value = model_card .get ("base_model" , model_card . get ( "base_models" , None ) )
355357
356358 if base_model_value is not None :
357359 if isinstance (base_model_value , str ):
@@ -364,18 +366,98 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
364366
365367 for model_id in metadata_base_models :
366368 # NOTE: model size of base model is assumed to be similar to the size of the current model
367- model_full_name_component , org_component , basename , finetune , version , size_label = Metadata .get_model_id_components (model_id , total_params )
368369 base_model = {}
369- if model_full_name_component is not None :
370- base_model ["name" ] = Metadata .id_to_title (model_full_name_component )
371- if org_component is not None :
372- base_model ["organization" ] = Metadata .id_to_title (org_component )
373- if version is not None :
374- base_model ["version" ] = version
375- if org_component is not None and model_full_name_component is not None :
376- base_model ["repo_url" ] = f"https://huggingface.co/{ org_component } /{ model_full_name_component } "
370+ if isinstance (model_id , str ):
371+ if model_id .startswith ("http://" ) or model_id .startswith ("https://" ) or model_id .startswith ("ssh://" ):
372+ base_model ["repo_url" ] = model_id
373+
374+ # Check if Hugging Face ID is present in URL
375+ if "huggingface.co" in model_id :
376+ match = re .match (r"https?://huggingface.co/([^/]+/[^/]+)$" , model_id )
377+ if match :
378+ model_id_component = match .group (1 )
379+ model_full_name_component , org_component , basename , finetune , version , size_label = Metadata .get_model_id_components (model_id_component , total_params )
380+
381+ # Populate model dictionary with extracted components
382+ if model_full_name_component is not None :
383+ base_model ["name" ] = Metadata .id_to_title (model_full_name_component )
384+ if org_component is not None :
385+ base_model ["organization" ] = Metadata .id_to_title (org_component )
386+ if version is not None :
387+ base_model ["version" ] = version
388+
389+ else :
390+ # Likely a Hugging Face ID
391+ model_full_name_component , org_component , basename , finetune , version , size_label = Metadata .get_model_id_components (model_id , total_params )
392+
393+ # Populate model dictionary with extracted components
394+ if model_full_name_component is not None :
395+ base_model ["name" ] = Metadata .id_to_title (model_full_name_component )
396+ if org_component is not None :
397+ base_model ["organization" ] = Metadata .id_to_title (org_component )
398+ if version is not None :
399+ base_model ["version" ] = version
400+ if org_component is not None and model_full_name_component is not None :
401+ base_model ["repo_url" ] = f"https://huggingface.co/{ org_component } /{ model_full_name_component } "
402+
403+ else :
404+ logger .error (f"base model entry '{ str (model_id )} ' not in a known format" )
377405 metadata .base_models .append (base_model )
378406
407+ if "datasets" in model_card or "dataset" in model_card :
408+ # This represents the datasets that this model was trained on
409+ metadata_datasets = []
410+ dataset_value = model_card .get ("datasets" , model_card .get ("dataset" , None ))
411+
412+ if dataset_value is not None :
413+ if isinstance (dataset_value , str ):
414+ metadata_datasets .append (dataset_value )
415+ elif isinstance (dataset_value , list ):
416+ metadata_datasets .extend (dataset_value )
417+
418+ if metadata .datasets is None :
419+ metadata .datasets = []
420+
421+ for dataset_id in metadata_datasets :
422+ # NOTE: dataset IDs are parsed with the same helper as model IDs; the size-related components it extracts are not used for datasets
423+ dataset = {}
424+ if isinstance (dataset_id , str ):
425+ if dataset_id .startswith (("http://" , "https://" , "ssh://" )):
426+ dataset ["repo_url" ] = dataset_id
427+
428+ # Check if Hugging Face ID is present in URL
429+ if "huggingface.co" in dataset_id :
430+ match = re .match (r"https?://huggingface.co/([^/]+/[^/]+)$" , dataset_id )
431+ if match :
432+ dataset_id_component = match .group (1 )
433+ dataset_name_component , org_component , basename , finetune , version , size_label = Metadata .get_model_id_components (dataset_id_component , total_params )
434+
435+ # Populate dataset dictionary with extracted components
436+ if dataset_name_component is not None :
437+ dataset ["name" ] = Metadata .id_to_title (dataset_name_component )
438+ if org_component is not None :
439+ dataset ["organization" ] = Metadata .id_to_title (org_component )
440+ if version is not None :
441+ dataset ["version" ] = version
442+
443+ else :
444+ # Likely a Hugging Face ID
445+ dataset_name_component , org_component , basename , finetune , version , size_label = Metadata .get_model_id_components (dataset_id , total_params )
446+
447+ # Populate dataset dictionary with extracted components
448+ if dataset_name_component is not None :
449+ dataset ["name" ] = Metadata .id_to_title (dataset_name_component )
450+ if org_component is not None :
451+ dataset ["organization" ] = Metadata .id_to_title (org_component )
452+ if version is not None :
453+ dataset ["version" ] = version
454+ if org_component is not None and dataset_name_component is not None :
455+ dataset ["repo_url" ] = f"https://huggingface.co/{ org_component } /{ dataset_name_component } "
456+
457+ else :
458+ logger .error (f"dataset entry '{ str (dataset_id )} ' not in a known format" )
459+ metadata .datasets .append (dataset )
460+
379461 use_model_card_metadata ("license" , "license" )
380462 use_model_card_metadata ("license_name" , "license_name" )
381463 use_model_card_metadata ("license_link" , "license_link" )
@@ -386,9 +468,6 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
386468 use_array_model_card_metadata ("languages" , "languages" )
387469 use_array_model_card_metadata ("languages" , "language" )
388470
389- use_array_model_card_metadata ("datasets" , "datasets" )
390- use_array_model_card_metadata ("datasets" , "dataset" )
391-
392471 # Hugging Face Parameter Heuristics
393472 ####################################
394473
@@ -493,6 +572,8 @@ def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
493572 gguf_writer .add_base_model_version (key , base_model_entry ["version" ])
494573 if "organization" in base_model_entry :
495574 gguf_writer .add_base_model_organization (key , base_model_entry ["organization" ])
575+ if "description" in base_model_entry :
576+ gguf_writer .add_base_model_description (key , base_model_entry ["description" ])
496577 if "url" in base_model_entry :
497578 gguf_writer .add_base_model_url (key , base_model_entry ["url" ])
498579 if "doi" in base_model_entry :
@@ -502,9 +583,29 @@ def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
502583 if "repo_url" in base_model_entry :
503584 gguf_writer .add_base_model_repo_url (key , base_model_entry ["repo_url" ])
504585
586+ if self .datasets is not None :
587+ gguf_writer .add_dataset_count (len (self .datasets ))
588+ for key , dataset_entry in enumerate (self .datasets ):
589+ if "name" in dataset_entry :
590+ gguf_writer .add_dataset_name (key , dataset_entry ["name" ])
591+ if "author" in dataset_entry :
592+ gguf_writer .add_dataset_author (key , dataset_entry ["author" ])
593+ if "version" in dataset_entry :
594+ gguf_writer .add_dataset_version (key , dataset_entry ["version" ])
595+ if "organization" in dataset_entry :
596+ gguf_writer .add_dataset_organization (key , dataset_entry ["organization" ])
597+ if "description" in dataset_entry :
598+ gguf_writer .add_dataset_description (key , dataset_entry ["description" ])
599+ if "url" in dataset_entry :
600+ gguf_writer .add_dataset_url (key , dataset_entry ["url" ])
601+ if "doi" in dataset_entry :
602+ gguf_writer .add_dataset_doi (key , dataset_entry ["doi" ])
603+ if "uuid" in dataset_entry :
604+ gguf_writer .add_dataset_uuid (key , dataset_entry ["uuid" ])
605+ if "repo_url" in dataset_entry :
606+ gguf_writer .add_dataset_repo_url (key , dataset_entry ["repo_url" ])
607+
505608 if self .tags is not None :
506609 gguf_writer .add_tags (self .tags )
507610 if self .languages is not None :
508611 gguf_writer .add_languages (self .languages )
509- if self .datasets is not None :
510- gguf_writer .add_datasets (self .datasets )
0 commit comments