@@ -41,7 +41,7 @@ class Metadata:
4141    base_models : Optional [list [dict ]] =  None 
4242    tags : Optional [list [str ]] =  None 
4343    languages : Optional [list [str ]] =  None 
44-     datasets : Optional [list [str ]] =  None 
44+     datasets : Optional [list [dict ]] =  None 
4545
4646    @staticmethod  
4747    def  load (metadata_override_path : Optional [Path ] =  None , model_path : Optional [Path ] =  None , model_name : Optional [str ] =  None , total_params : int  =  0 ) ->  Metadata :
@@ -91,9 +91,11 @@ def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Pat
9191        # Base Models is received here as an array of models 
9292        metadata .base_models      =  metadata_override .get ("general.base_models" ,        metadata .base_models )
9393
94+         # Datasets is received here as an array of datasets 
95+         metadata .datasets         =  metadata_override .get ("general.datasets" ,           metadata .datasets )
96+ 
9497        metadata .tags             =  metadata_override .get (Keys .General .TAGS ,            metadata .tags )
9598        metadata .languages        =  metadata_override .get (Keys .General .LANGUAGES ,       metadata .languages )
96-         metadata .datasets         =  metadata_override .get (Keys .General .DATASETS ,        metadata .datasets )
9799
98100        # Direct Metadata Override (via direct cli argument) 
99101        if  model_name  is  not   None :
@@ -346,12 +348,12 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
346348            use_model_card_metadata ("author" , "model_creator" )
347349            use_model_card_metadata ("basename" , "model_type" )
348350
349-             if  "base_model"  in  model_card :
351+             if  "base_model"  in  model_card   or   "base_models"   in   model_card   or   "base_model_sources"   in   model_card :
350352                # This represents the parent models that this is based on 
351353                # Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges) 
352354                # Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md 
353355                metadata_base_models  =  []
354-                 base_model_value  =  model_card .get ("base_model" , None )
356+                 base_model_value  =  model_card .get ("base_model" , model_card . get ( "base_models" ,  model_card . get ( "base_model_sources" ,  None )) )
355357
356358                if  base_model_value  is  not   None :
357359                    if  isinstance (base_model_value , str ):
@@ -364,18 +366,106 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
364366
365367                for  model_id  in  metadata_base_models :
366368                    # NOTE: model size of base model is assumed to be similar to the size of the current model 
367-                     model_full_name_component , org_component , basename , finetune , version , size_label  =  Metadata .get_model_id_components (model_id , total_params )
368369                    base_model  =  {}
369-                     if  model_full_name_component  is  not   None :
370-                         base_model ["name" ] =  Metadata .id_to_title (model_full_name_component )
371-                     if  org_component  is  not   None :
372-                         base_model ["organization" ] =  Metadata .id_to_title (org_component )
373-                     if  version  is  not   None :
374-                         base_model ["version" ] =  version 
375-                     if  org_component  is  not   None  and  model_full_name_component  is  not   None :
376-                         base_model ["repo_url" ] =  f"https://huggingface.co/{ org_component }  /{ model_full_name_component }  " 
370+                     if  isinstance (model_id , str ):
371+                         if  model_id .startswith ("http://" ) or  model_id .startswith ("https://" ) or  model_id .startswith ("ssh://" ):
372+                             base_model ["repo_url" ] =  model_id 
373+ 
374+                             # Check if Hugging Face ID is present in URL 
375+                             if  "huggingface.co"  in  model_id :
376+                                 match  =  re .match (r"https?://huggingface.co/([^/]+/[^/]+)$" , model_id )
377+                                 if  match :
378+                                     model_id_component  =  match .group (1 )
379+                                     model_full_name_component , org_component , basename , finetune , version , size_label  =  Metadata .get_model_id_components (model_id_component , total_params )
380+ 
381+                                     # Populate model dictionary with extracted components 
382+                                     if  model_full_name_component  is  not   None :
383+                                         base_model ["name" ] =  Metadata .id_to_title (model_full_name_component )
384+                                     if  org_component  is  not   None :
385+                                         base_model ["organization" ] =  Metadata .id_to_title (org_component )
386+                                     if  version  is  not   None :
387+                                         base_model ["version" ] =  version 
388+ 
389+                         else :
390+                             # Likely a Hugging Face ID 
391+                             model_full_name_component , org_component , basename , finetune , version , size_label  =  Metadata .get_model_id_components (model_id , total_params )
392+ 
393+                             # Populate model dictionary with extracted components 
394+                             if  model_full_name_component  is  not   None :
395+                                 base_model ["name" ] =  Metadata .id_to_title (model_full_name_component )
396+                             if  org_component  is  not   None :
397+                                 base_model ["organization" ] =  Metadata .id_to_title (org_component )
398+                             if  version  is  not   None :
399+                                 base_model ["version" ] =  version 
400+                             if  org_component  is  not   None  and  model_full_name_component  is  not   None :
401+                                 base_model ["repo_url" ] =  f"https://huggingface.co/{ org_component }  /{ model_full_name_component }  " 
402+ 
403+                     elif  isinstance (model_id , dict ):
404+                         base_model  =  model_id 
405+ 
406+                     else :
407+                         logger .error (f"base model entry '{ str (model_id )}  ' not in a known format" )
408+ 
377409                    metadata .base_models .append (base_model )
378410
411+             if  "datasets"  in  model_card  or  "dataset"  in  model_card  or  "dataset_sources"  in  model_card :
412+                 # This represents the datasets that this was trained from 
413+                 metadata_datasets  =  []
414+                 dataset_value  =  model_card .get ("datasets" , model_card .get ("dataset" , model_card .get ("dataset_sources" , None )))
415+ 
416+                 if  dataset_value  is  not   None :
417+                     if  isinstance (dataset_value , str ):
418+                         metadata_datasets .append (dataset_value )
419+                     elif  isinstance (dataset_value , list ):
420+                         metadata_datasets .extend (dataset_value )
421+ 
422+                 if  metadata .datasets  is  None :
423+                     metadata .datasets  =  []
424+ 
425+                 for  dataset_id  in  metadata_datasets :
426+                     # NOTE: model size of base model is assumed to be similar to the size of the current model 
427+                     dataset  =  {}
428+                     if  isinstance (dataset_id , str ):
429+                         if  dataset_id .startswith (("http://" , "https://" , "ssh://" )):
430+                             dataset ["repo_url" ] =  dataset_id 
431+ 
432+                             # Check if Hugging Face ID is present in URL 
433+                             if  "huggingface.co"  in  dataset_id :
434+                                 match  =  re .match (r"https?://huggingface.co/([^/]+/[^/]+)$" , dataset_id )
435+                                 if  match :
436+                                     dataset_id_component  =  match .group (1 )
437+                                     dataset_name_component , org_component , basename , finetune , version , size_label  =  Metadata .get_model_id_components (dataset_id_component , total_params )
438+ 
439+                                     # Populate dataset dictionary with extracted components 
440+                                     if  dataset_name_component  is  not   None :
441+                                         dataset ["name" ] =  Metadata .id_to_title (dataset_name_component )
442+                                     if  org_component  is  not   None :
443+                                         dataset ["organization" ] =  Metadata .id_to_title (org_component )
444+                                     if  version  is  not   None :
445+                                         dataset ["version" ] =  version 
446+ 
447+                         else :
448+                             # Likely a Hugging Face ID 
449+                             dataset_name_component , org_component , basename , finetune , version , size_label  =  Metadata .get_model_id_components (dataset_id , total_params )
450+ 
451+                             # Populate dataset dictionary with extracted components 
452+                             if  dataset_name_component  is  not   None :
453+                                 dataset ["name" ] =  Metadata .id_to_title (dataset_name_component )
454+                             if  org_component  is  not   None :
455+                                 dataset ["organization" ] =  Metadata .id_to_title (org_component )
456+                             if  version  is  not   None :
457+                                 dataset ["version" ] =  version 
458+                             if  org_component  is  not   None  and  dataset_name_component  is  not   None :
459+                                 dataset ["repo_url" ] =  f"https://huggingface.co/{ org_component }  /{ dataset_name_component }  " 
460+ 
461+                     elif  isinstance (dataset_id , dict ):
462+                         dataset  =  dataset_id 
463+ 
464+                     else :
465+                         logger .error (f"dataset entry '{ str (dataset_id )}  ' not in a known format" )
466+ 
467+                     metadata .datasets .append (dataset )
468+ 
379469            use_model_card_metadata ("license" , "license" )
380470            use_model_card_metadata ("license_name" , "license_name" )
381471            use_model_card_metadata ("license_link" , "license_link" )
@@ -386,9 +476,6 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
386476            use_array_model_card_metadata ("languages" , "languages" )
387477            use_array_model_card_metadata ("languages" , "language" )
388478
389-             use_array_model_card_metadata ("datasets" , "datasets" )
390-             use_array_model_card_metadata ("datasets" , "dataset" )
391- 
392479        # Hugging Face Parameter Heuristics 
393480        #################################### 
394481
@@ -493,6 +580,8 @@ def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
493580                    gguf_writer .add_base_model_version (key , base_model_entry ["version" ])
494581                if  "organization"  in  base_model_entry :
495582                    gguf_writer .add_base_model_organization (key , base_model_entry ["organization" ])
583+                 if  "description"  in  base_model_entry :
584+                     gguf_writer .add_base_model_description (key , base_model_entry ["description" ])
496585                if  "url"  in  base_model_entry :
497586                    gguf_writer .add_base_model_url (key , base_model_entry ["url" ])
498587                if  "doi"  in  base_model_entry :
@@ -502,9 +591,29 @@ def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
502591                if  "repo_url"  in  base_model_entry :
503592                    gguf_writer .add_base_model_repo_url (key , base_model_entry ["repo_url" ])
504593
594+         if  self .datasets  is  not   None :
595+             gguf_writer .add_dataset_count (len (self .datasets ))
596+             for  key , dataset_entry  in  enumerate (self .datasets ):
597+                 if  "name"  in  dataset_entry :
598+                     gguf_writer .add_dataset_name (key , dataset_entry ["name" ])
599+                 if  "author"  in  dataset_entry :
600+                     gguf_writer .add_dataset_author (key , dataset_entry ["author" ])
601+                 if  "version"  in  dataset_entry :
602+                     gguf_writer .add_dataset_version (key , dataset_entry ["version" ])
603+                 if  "organization"  in  dataset_entry :
604+                     gguf_writer .add_dataset_organization (key , dataset_entry ["organization" ])
605+                 if  "description"  in  dataset_entry :
606+                     gguf_writer .add_dataset_description (key , dataset_entry ["description" ])
607+                 if  "url"  in  dataset_entry :
608+                     gguf_writer .add_dataset_url (key , dataset_entry ["url" ])
609+                 if  "doi"  in  dataset_entry :
610+                     gguf_writer .add_dataset_doi (key , dataset_entry ["doi" ])
611+                 if  "uuid"  in  dataset_entry :
612+                     gguf_writer .add_dataset_uuid (key , dataset_entry ["uuid" ])
613+                 if  "repo_url"  in  dataset_entry :
614+                     gguf_writer .add_dataset_repo_url (key , dataset_entry ["repo_url" ])
615+ 
505616        if  self .tags  is  not   None :
506617            gguf_writer .add_tags (self .tags )
507618        if  self .languages  is  not   None :
508619            gguf_writer .add_languages (self .languages )
509-         if  self .datasets  is  not   None :
510-             gguf_writer .add_datasets (self .datasets )
0 commit comments