Skip to content

Commit 0020154

Browse files
committed
better snippets for KerasHub models
1 parent 4606df3 commit 0020154

File tree

2 files changed

+67
-8
lines changed

2 files changed

+67
-8
lines changed

packages/tasks/src/model-data.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,16 @@ export interface ModelData {
6666
base_model_name_or_path?: string;
6767
task_type?: string;
6868
};
69+
keras_hub_task_json?: {
70+
class_name: string;
71+
alt_class_names?: string[];
72+
};
73+
keras_hub_config_json?: {
74+
class_name: string;
75+
};
76+
keras_hub_tokenizer_json?: {
77+
class_name: string;
78+
};
6979
};
7080
/**
7181
* all the model tags

packages/tasks/src/model-libraries-snippets.ts

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -403,20 +403,69 @@ backbone = keras_nlp.models.Backbone.from_preset("hf://${model.id}")
403403
`,
404404
];
405405

406-
export const keras_hub = (model: ModelData): string[] => [
407-
`# Available backend options are: "jax", "torch", "tensorflow".
406+
export function keras_hub(model: ModelData): string[] {
407+
// If the model has a task.json config, then the base Task class is known
408+
let class_name = model.config?.keras_hub_task_json?.class_name;
409+
if (!class_name)
410+
// If only a config.json is present, the base class will be a "backbone"
411+
class_name = model.config?.keras_hub_config_json?.class_name;
412+
413+
// Fallback heuristic: until task.json is populated in all keras-hub models,
414+
// make a best effort, for text-generation models only, to disply
415+
// a "XXXCausalLM" base class instead of XXXBackbone.
416+
if (model.pipeline_tag == "text-generation" && class_name?.endsWith("Backbone"))
417+
class_name = class_name.replace("Backbone", "CausalLM");
418+
419+
// optional generation snippets
420+
const optional_snippets = [
421+
["text-generation", 'model.generate("Keras: deep learning for", max_length=64)'],
422+
[
423+
"image-text-to-text",
424+
`output = model.generate(
425+
inputs={
426+
"images": image,
427+
"prompts": prompt,
428+
}
429+
)`,
430+
],
431+
];
432+
const selected_snippet_row = optional_snippets.filter((cols) => cols[0] == model.pipeline_tag);
433+
const optional_snippet = selected_snippet_row.length == 0 ? "" : selected_snippet_row[0][1];
434+
435+
// de-duplicate possible alt classes
436+
// from task.json
437+
const alt_class_names = new Set(model.config?.keras_hub_task_json?.alt_class_names);
438+
if (class_name) alt_class_names.delete(class_name);
439+
// and from tokenizer.json
440+
if (model.config?.keras_hub_tokenizer_json?.class_name)
441+
alt_class_names.add(model.config?.keras_hub_tokenizer_json?.class_name);
442+
// generate possible alternative class.from_preset() calls.
443+
let alt_model_component_snippets = undefined;
444+
if (alt_class_names.size > 0) {
445+
const alt_model_component_snippet_lines = Array.from(alt_class_names).map(
446+
(k) => `model = keras_hub.models.${k}.from_preset("hf://${model.id}")`
447+
);
448+
alt_model_component_snippets =
449+
"# Individual model components can also be loaded from this preset:\n" +
450+
alt_model_component_snippet_lines.join("\n");
451+
}
452+
453+
const main_snippet = ` # Available backend options are: "jax", "torch", "tensorflow".
408454
import os
409455
os.environ["KERAS_BACKEND"] = "jax"
410456
411457
import keras_hub
412458
413-
# Load a task-specific model (*replace CausalLM with your task*)
414-
model = keras_hub.models.CausalLM.from_preset("hf://${model.id}", dtype="bfloat16")
459+
model = keras_hub.models.${class_name}.from_preset("hf://${model.id}")
460+
${optional_snippet}
415461
416-
# Possible tasks are CausalLM, TextToImage, ImageClassifier, ...
417-
# full list here: https://keras.io/api/keras_hub/models/#api-documentation
418-
`,
419-
];
462+
# All Keras models support: model(data), model.compile, model.fit, model.predict, model.evaluate.
463+
# More info on this model: https://keras.io/search.html?query=${class_name}%20keras_hub
464+
`;
465+
const snippets = [main_snippet];
466+
if (alt_model_component_snippets) snippets.push(alt_model_component_snippets);
467+
return snippets;
468+
}
420469

421470
export const llama_cpp_python = (model: ModelData): string[] => [
422471
`from llama_cpp import Llama

0 commit comments

Comments
 (0)