@@ -403,20 +403,69 @@ backbone = keras_nlp.models.Backbone.from_preset("hf://${model.id}")
403403` ,
404404] ;
405405
406- export const keras_hub = ( model : ModelData ) : string [ ] => [
407- `# Available backend options are: "jax", "torch", "tensorflow".
406+ export function keras_hub ( model : ModelData ) : string [ ] {
407+ // If the model has a task.json config, then the base Task class is known
408+ let class_name = model . config ?. keras_hub_task_json ?. class_name ;
409+ if ( ! class_name )
410+ // If only a config.json is present, the base class will be a "backbone"
411+ class_name = model . config ?. keras_hub_config_json ?. class_name ;
412+
413+ // Fallback heuristic: until task.json is populated in all keras-hub models,
414+ // make a best effort, for text-generation models only, to disply
415+ // a "XXXCausalLM" base class instead of XXXBackbone.
416+ if ( model . pipeline_tag == "text-generation" && class_name ?. endsWith ( "Backbone" ) )
417+ class_name = class_name . replace ( "Backbone" , "CausalLM" ) ;
418+
419+ // optional generation snippets
420+ const optional_snippets = [
421+ [ "text-generation" , 'model.generate("Keras: deep learning for", max_length=64)' ] ,
422+ [
423+ "image-text-to-text" ,
424+ `output = model.generate(
425+ inputs={
426+ "images": image,
427+ "prompts": prompt,
428+ }
429+ )` ,
430+ ] ,
431+ ] ;
432+ const selected_snippet_row = optional_snippets . filter ( ( cols ) => cols [ 0 ] == model . pipeline_tag ) ;
433+ const optional_snippet = selected_snippet_row . length == 0 ? "" : selected_snippet_row [ 0 ] [ 1 ] ;
434+
435+ // de-duplicate possible alt classes
436+ // from task.json
437+ const alt_class_names = new Set ( model . config ?. keras_hub_task_json ?. alt_class_names ) ;
438+ if ( class_name ) alt_class_names . delete ( class_name ) ;
439+ // and from tokenizer.json
440+ if ( model . config ?. keras_hub_tokenizer_json ?. class_name )
441+ alt_class_names . add ( model . config ?. keras_hub_tokenizer_json ?. class_name ) ;
442+ // generate possible alternative class.from_preset() calls.
443+ let alt_model_component_snippets = undefined ;
444+ if ( alt_class_names . size > 0 ) {
445+ const alt_model_component_snippet_lines = Array . from ( alt_class_names ) . map (
446+ ( k ) => `model = keras_hub.models.${ k } .from_preset("hf://${ model . id } ")`
447+ ) ;
448+ alt_model_component_snippets =
449+ "# Individual model components can also be loaded from this preset:\n" +
450+ alt_model_component_snippet_lines . join ( "\n" ) ;
451+ }
452+
453+ const main_snippet = ` # Available backend options are: "jax", "torch", "tensorflow".
408454import os
409455os.environ["KERAS_BACKEND"] = "jax"
410456
411457import keras_hub
412458
413- # Load a task-specific model (*replace CausalLM with your task* )
414- model = keras_hub.models.CausalLM.from_preset("hf:// ${ model . id } ", dtype="bfloat16")
459+ model = keras_hub.models. ${ class_name } .from_preset("hf:// ${ model . id } " )
460+ ${ optional_snippet }
415461
416- # Possible tasks are CausalLM, TextToImage, ImageClassifier, ...
417- # full list here: https://keras.io/api/keras_hub/models/#api-documentation
418- ` ,
419- ] ;
462+ # All Keras models support: model(data), model.compile, model.fit, model.predict, model.evaluate.
463+ # More info on this model: https://keras.io/search.html?query=${ class_name } %20keras_hub
464+ ` ;
465+ const snippets = [ main_snippet ] ;
466+ if ( alt_model_component_snippets ) snippets . push ( alt_model_component_snippets ) ;
467+ return snippets ;
468+ }
420469
421470export const llama_cpp_python = ( model : ModelData ) : string [ ] => [
422471 `from llama_cpp import Llama
0 commit comments