Skip to content

Commit 69ceb3a

Browse files
authored
Generate KerasHub snippets based on tasks from metadata.json (#1118)
This PR updates the `keras-hub` snippets based on the new `metadata.json > tasks` field. This field is now uploaded for all KerasHub models (see keras-team/keras-hub#1997) and contains the list of tasks compatible for a given model. This allows us to generate multiple snippets when relevant. For instance, [keras/stable_diffusion_3.5_large_turbo](https://huggingface.co/keras/stable_diffusion_3.5_large_turbo/blob/main/metadata.json) is compatible with `ImageToImage`, `Inpaint` and `TextToImage` tasks. For this PR to work, we'll need to parse the metadata.json file server-side. This is done in huggingface-internal/moon-landing#11693 (private PR). We can merge these 2 PRs independently. cc @martin-gorner @mattdangerw @SamanehSaadat who coordinated this --- **Note:** I also removed the legacy `keras-nlp` library (only [18 remaining models](https://huggingface.co/models?library=keras-nlp)).
1 parent 6994a90 commit 69ceb3a

File tree

3 files changed

+93
-26
lines changed

3 files changed

+93
-26
lines changed

packages/tasks/src/model-data.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ export interface ModelData {
6666
base_model_name_or_path?: string;
6767
task_type?: string;
6868
};
69+
keras_hub?: {
70+
tasks?: string[];
71+
};
6972
};
7073
/**
7174
* all the model tags

packages/tasks/src/model-libraries-snippets.ts

Lines changed: 90 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -394,32 +394,103 @@ model = keras.saving.load_model("hf://${model.id}")
394394
`,
395395
];
396396

397-
export const keras_nlp = (model: ModelData): string[] => [
398-
`# Available backend options are: "jax", "torch", "tensorflow".
399-
import os
400-
os.environ["KERAS_BACKEND"] = "jax"
397+
const _keras_hub_causal_lm = (modelId: string): string => `
398+
import keras_hub
401399
402-
import keras_nlp
400+
# Load CausalLM model (optional: use half precision for inference)
401+
causal_lm = keras_hub.models.CausalLM.from_preset(${modelId}, dtype="bfloat16")
402+
causal_lm.compile(sampler="greedy") # (optional) specify a sampler
403403
404-
tokenizer = keras_nlp.models.Tokenizer.from_preset("hf://${model.id}")
405-
backbone = keras_nlp.models.Backbone.from_preset("hf://${model.id}")
406-
`,
407-
];
404+
# Generate text
405+
causal_lm.generate("Keras: deep learning for", max_length=64)
406+
`;
408407

409-
export const keras_hub = (model: ModelData): string[] => [
410-
`# Available backend options are: "jax", "torch", "tensorflow".
411-
import os
412-
os.environ["KERAS_BACKEND"] = "jax"
408+
const _keras_hub_text_to_image = (modelId: string): string => `
409+
import keras_hub
413410
411+
# Load TextToImage model (optional: use half precision for inference)
412+
text_to_image = keras_hub.models.TextToImage.from_preset(${modelId}, dtype="bfloat16")
413+
414+
# Generate images with a TextToImage model.
415+
text_to_image.generate("Astronaut in a jungle")
416+
`;
417+
418+
const _keras_hub_text_classifier = (modelId: string): string => `
414419
import keras_hub
415420
416-
# Load a task-specific model (*replace CausalLM with your task*)
417-
model = keras_hub.models.CausalLM.from_preset("hf://${model.id}", dtype="bfloat16")
421+
# Load TextClassifier model
422+
text_classifier = keras_hub.models.TextClassifier.from_preset(
423+
${modelId},
424+
num_classes=2,
425+
)
426+
# Fine-tune
427+
text_classifier.fit(x=["Thilling adventure!", "Total snoozefest."], y=[1, 0])
428+
# Classify text
429+
text_classifier.predict(["Not my cup of tea."])
430+
`;
418431

419-
# Possible tasks are CausalLM, TextToImage, ImageClassifier, ...
420-
# full list here: https://keras.io/api/keras_hub/models/#api-documentation
421-
`,
422-
];
432+
const _keras_hub_image_classifier = (modelId: string): string => `
433+
import keras_hub
434+
import keras
435+
436+
# Load ImageClassifier model
437+
image_classifier = keras_hub.models.ImageClassifier.from_preset(
438+
${modelId},
439+
num_classes=2,
440+
)
441+
# Fine-tune
442+
image_classifier.fit(
443+
x=keras.random.randint((32, 64, 64, 3), 0, 256),
444+
y=keras.random.randint((32, 1), 0, 2),
445+
)
446+
# Classify image
447+
image_classifier.predict(keras.random.randint((1, 64, 64, 3), 0, 256))
448+
`;
449+
450+
const _keras_hub_tasks_with_example = {
451+
CausalLM: _keras_hub_causal_lm,
452+
TextToImage: _keras_hub_text_to_image,
453+
TextClassifier: _keras_hub_text_classifier,
454+
ImageClassifier: _keras_hub_image_classifier,
455+
};
456+
457+
const _keras_hub_task_without_example = (task: string, modelId: string): string => `
458+
import keras_hub
459+
460+
# Create a ${task} model
461+
task = keras_hub.models.${task}.from_preset(${modelId})
462+
`;
463+
464+
const _keras_hub_generic_backbone = (modelId: string): string => `
465+
import keras_hub
466+
467+
# Create a Backbone model unspecialized for any task
468+
backbone = keras_hub.models.Backbone.from_preset(${modelId})
469+
`;
470+
471+
export const keras_hub = (model: ModelData): string[] => {
472+
const modelId = model.id;
473+
const tasks = model.config?.keras_hub?.tasks ?? [];
474+
475+
const snippets: string[] = [];
476+
477+
// First, generate tasks with examples
478+
for (const [task, snippet] of Object.entries(_keras_hub_tasks_with_example)) {
479+
if (tasks.includes(task)) {
480+
snippets.push(snippet(modelId));
481+
}
482+
}
483+
// Then, add remaining tasks
484+
for (const task in tasks) {
485+
if (!Object.keys(_keras_hub_tasks_with_example).includes(task)) {
486+
snippets.push(_keras_hub_task_without_example(task, modelId));
487+
}
488+
}
489+
// Finally, add generic backbone snippet
490+
snippets.push(_keras_hub_generic_backbone(modelId));
491+
492+
return snippets;
493+
};
423494

424495
export const llama_cpp_python = (model: ModelData): string[] => {
425496
const snippets = [

packages/tasks/src/model-libraries.ts

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -404,13 +404,6 @@ export const MODEL_LIBRARIES_UI_ELEMENTS = {
404404
snippets: snippets.tf_keras,
405405
countDownloads: `path:"saved_model.pb"`,
406406
},
407-
"keras-nlp": {
408-
prettyLabel: "KerasNLP",
409-
repoName: "KerasNLP",
410-
repoUrl: "https://github.com/keras-team/keras-nlp",
411-
docsUrl: "https://keras.io/keras_nlp/",
412-
snippets: snippets.keras_nlp,
413-
},
414407
"keras-hub": {
415408
prettyLabel: "KerasHub",
416409
repoName: "KerasHub",

0 commit comments

Comments
 (0)