@@ -394,32 +394,103 @@ model = keras.saving.load_model("hf://${model.id}")
394394` ,
395395] ;
396396
397- export const keras_nlp = ( model : ModelData ) : string [ ] => [
398- `# Available backend options are: "jax", "torch", "tensorflow".
399- import os
400- os.environ["KERAS_BACKEND"] = "jax"
397+ const _keras_hub_causal_lm = ( modelId : string ) : string => `
398+ import keras_hub
401399
402- import keras_nlp
400+ # Load CausalLM model (optional: use half precision for inference)
401+ causal_lm = keras_hub.models.CausalLM.from_preset(${ modelId } , dtype="bfloat16")
402+ causal_lm.compile(sampler="greedy") # (optional) specify a sampler
403403
404- tokenizer = keras_nlp.models.Tokenizer.from_preset("hf://${ model . id } ")
405- backbone = keras_nlp.models.Backbone.from_preset("hf://${ model . id } ")
406- ` ,
407- ] ;
404+ # Generate text
405+ causal_lm.generate("Keras: deep learning for", max_length=64)
406+ ` ;
408407
409- export const keras_hub = ( model : ModelData ) : string [ ] => [
410- `# Available backend options are: "jax", "torch", "tensorflow".
411- import os
412- os.environ["KERAS_BACKEND"] = "jax"
408+ const _keras_hub_text_to_image = ( modelId : string ) : string => `
409+ import keras_hub
413410
411+ # Load TextToImage model (optional: use half precision for inference)
412+ text_to_image = keras_hub.models.TextToImage.from_preset(${ modelId } , dtype="bfloat16")
413+
414+ # Generate images with a TextToImage model.
415+ text_to_image.generate("Astronaut in a jungle")
416+ ` ;
417+
418+ const _keras_hub_text_classifier = ( modelId : string ) : string => `
414419import keras_hub
415420
416- # Load a task-specific model (*replace CausalLM with your task*)
417- model = keras_hub.models.CausalLM.from_preset("hf://${ model . id } ", dtype="bfloat16")
421+ # Load TextClassifier model
422+ text_classifier = keras_hub.models.TextClassifier.from_preset(
423+ ${ modelId } ,
424+ num_classes=2,
425+ )
426+ # Fine-tune
427+ text_classifier.fit(x=["Thilling adventure!", "Total snoozefest."], y=[1, 0])
428+ # Classify text
429+ text_classifier.predict(["Not my cup of tea."])
430+ ` ;
418431
419- # Possible tasks are CausalLM, TextToImage, ImageClassifier, ...
420- # full list here: https://keras.io/api/keras_hub/models/#api-documentation
421- ` ,
422- ] ;
432+ const _keras_hub_image_classifier = ( modelId : string ) : string => `
433+ import keras_hub
434+ import keras
435+
436+ # Load ImageClassifier model
437+ image_classifier = keras_hub.models.ImageClassifier.from_preset(
438+ ${ modelId } ,
439+ num_classes=2,
440+ )
441+ # Fine-tune
442+ image_classifier.fit(
443+ x=keras.random.randint((32, 64, 64, 3), 0, 256),
444+ y=keras.random.randint((32, 1), 0, 2),
445+ )
446+ # Classify image
447+ image_classifier.predict(keras.random.randint((1, 64, 64, 3), 0, 256))
448+ ` ;
449+
450+ const _keras_hub_tasks_with_example = {
451+ CausalLM : _keras_hub_causal_lm ,
452+ TextToImage : _keras_hub_text_to_image ,
453+ TextClassifier : _keras_hub_text_classifier ,
454+ ImageClassifier : _keras_hub_image_classifier ,
455+ } ;
456+
457+ const _keras_hub_task_without_example = ( task : string , modelId : string ) : string => `
458+ import keras_hub
459+
460+ # Create a ${ task } model
461+ task = keras_hub.models.${ task } .from_preset(${ modelId } )
462+ ` ;
463+
464+ const _keras_hub_generic_backbone = ( modelId : string ) : string => `
465+ import keras_hub
466+
467+ # Create a Backbone model unspecialized for any task
468+ backbone = keras_hub.models.Backbone.from_preset(${ modelId } )
469+ ` ;
470+
471+ export const keras_hub = ( model : ModelData ) : string [ ] => {
472+ const modelId = model . id ;
473+ const tasks = model . config ?. keras_hub ?. tasks ?? [ ] ;
474+
475+ const snippets : string [ ] = [ ] ;
476+
477+ // First, generate tasks with examples
478+ for ( const [ task , snippet ] of Object . entries ( _keras_hub_tasks_with_example ) ) {
479+ if ( tasks . includes ( task ) ) {
480+ snippets . push ( snippet ( modelId ) ) ;
481+ }
482+ }
483+ // Then, add remaining tasks
484+ for ( const task in tasks ) {
485+ if ( ! Object . keys ( _keras_hub_tasks_with_example ) . includes ( task ) ) {
486+ snippets . push ( _keras_hub_task_without_example ( task , modelId ) ) ;
487+ }
488+ }
489+ // Finally, add generic backbone snippet
490+ snippets . push ( _keras_hub_generic_backbone ( modelId ) ) ;
491+
492+ return snippets ;
493+ } ;
423494
424495export const llama_cpp_python = ( model : ModelData ) : string [ ] => {
425496 const snippets = [
0 commit comments