diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts index 15e1774fe7..dad22a8277 100644 --- a/packages/tasks/src/model-libraries-snippets.ts +++ b/packages/tasks/src/model-libraries-snippets.ts @@ -65,7 +65,7 @@ from audioseal import AudioSeal model = AudioSeal.load_generator("${model.id}") # pass a tensor (tensor_wav) of shape (batch, channels, samples) and a sample rate wav, sr = tensor_wav, 16000 - + watermark = model.get_watermark(wav, sr) watermarked_audio = wav + watermark`; @@ -73,7 +73,7 @@ watermarked_audio = wav + watermark`; from audioseal import AudioSeal detector = AudioSeal.load_detector("${model.id}") - + result, message = detector.detect_watermark(watermarked_audio, sr)`; return [watermarkSnippet, detectorSnippet]; }; @@ -580,7 +580,7 @@ export const cartesia_mlx = (model: ModelData): string[] => [ import cartesia_mlx as cmx model = cmx.from_pretrained("${model.id}") -model.set_dtype(mx.float32) +model.set_dtype(mx.float32) prompt = "Rene Descartes was" @@ -703,7 +703,7 @@ export const keras = (model: ModelData): string[] => [ `# Available backend options are: "jax", "torch", "tensorflow". import os os.environ["KERAS_BACKEND"] = "jax" - + import keras model = keras.saving.load_model("hf://${model.id}") @@ -867,7 +867,7 @@ model.score("query", ["doc1", "doc2", "doc3"])`, from lightning_ir import BiEncoderModule, CrossEncoderModule # depending on the model type, use either BiEncoderModule or CrossEncoderModule -model = BiEncoderModule("${model.id}") +model = BiEncoderModule("${model.id}") # model = CrossEncoderModule("${model.id}") model.score("query", ["doc1", "doc2", "doc3"])`, @@ -916,7 +916,7 @@ pip install -e .[smolvla]`, `# Launch finetuning on your dataset python lerobot/scripts/train.py \\ --policy.path=${model.id} \\ ---dataset.repo_id=lerobot/svla_so101_pickplace \\ +--dataset.repo_id=lerobot/svla_so101_pickplace \\ --batch_size=64 \\ --steps=20000 \\ --output_dir=outputs/train/my_smolvla \\ @@ -927,7 +927,7 @@ python lerobot/scripts/train.py \\ if (model.id !== "lerobot/smolvla_base") { // Inference snippet (only if not base model) smolvlaSnippets.push( - `# Run the policy using the record function + `# Run the policy using the record function python -m lerobot.record \\ --robot.type=so101_follower \\ --robot.port=/dev/ttyACM0 \\ # <- Use your port @@ -1067,7 +1067,7 @@ for res in output: } return [ - `# Please refer to the document for information on how to use the model. + `# Please refer to the document for information on how to use the model. # https://paddlepaddle.github.io/PaddleOCR/latest/en/version3.x/module_usage/module_overview.html`, ]; }; @@ -1103,7 +1103,7 @@ wan_i2v = WanI2V( export const pyannote_audio_pipeline = (model: ModelData): string[] => [ `from pyannote.audio import Pipeline - + pipeline = Pipeline.from_pretrained("${model.id}") # inference on the whole file @@ -1142,7 +1142,7 @@ export const pyannote_audio = (model: ModelData): string[] => { export const relik = (model: ModelData): string[] => [ `from relik import Relik - + relik = Relik.from_pretrained("${model.id}")`, ]; @@ -1326,7 +1326,7 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): const video_predictor = `# Use SAM2 with videos import torch from sam2.sam2_video_predictor import SAM2VideoPredictor - + predictor = SAM2VideoPredictor.from_pretrained(${model.id}) with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): @@ -1885,6 +1885,115 @@ export const model2vec = (model: ModelData): string[] => [ model = StaticModel.from_pretrained("${model.id}")`, ]; +// Main entry point for generating Pruna code snippets based on the model's library +export const pruna = (model: ModelData): string[] => { + // If the model uses diffusers, use the diffusers-specific snippet + if (model.library_name === "diffusers") { + return pruna_diffusers(model); + // If the model uses transformers, use the transformers-specific snippet + } else if (model.library_name === "transformers") { + return pruna_transformers(model); + } + // Fallback to a default snippet + return pruna_default(model); +}; + +// Generates Pruna code snippets for diffusers models +export const pruna_diffusers = (model: ModelData): string[] => { + const diffusersSnippets = diffusers(model); + + // Replace all pipeline class names with PrunaModel + const rewrittenSnippets = diffusersSnippets.map(snippet => + snippet + // First, replace ALL pipeline class references (including in imports) + // Handle classes that end with "Pipeline" (like DiffusionPipeline) + .replace(/\b(\w+Pipeline)\b/g, "PrunaModel") + // Handle classes that contain "Pipeline" but don't end with it (like AutoPipelineForInpainting) + .replace(/\b(\w+Pipeline\w+)\b/g, "PrunaModel") + // Then handle imports - remove pipeline imports completely + .replace(/from diffusers import ([^,\n]*PrunaModel[^,\n]*)/g, "") + .replace(/from diffusers import ([^,\n]+),?\s*([^,\n]*PrunaModel[^,\n]*)/g, "from diffusers import $1") + // Clean up any empty import lines + .replace(/from diffusers import\s*\n/g, "") + .replace(/from diffusers import\s*$/g, "") + // Clean up any double newlines + .replace(/\n\n+/g, "\n") + // Fix the PrunaModel import issue - replace any remaining "from diffusers import PrunaModel" + .replace(/from diffusers import PrunaModel/g, "from pruna import PrunaModel") + // Handle cases where PrunaModel is imported alongside other classes + .replace(/from diffusers import ([^,\n]+), PrunaModel/g, "from diffusers import $1") + .replace(/from diffusers import PrunaModel, ([^,\n]+)/g, "from diffusers import $1") + .trim() + ); + + // Always add the PrunaModel import at the beginning of each snippet + const finalSnippets = rewrittenSnippets.map(snippet => { + if (!/^from pruna import PrunaModel/m.test(snippet)) { + return `from pruna import PrunaModel\n${snippet}`; + } + return snippet; + }); + + return finalSnippets; +}; + +// Generates Pruna code snippets for transformers models +export const pruna_transformers = (model: ModelData): string[] => { + const info = model.transformersInfo; + const transformersSnippets = transformers(model); + + // Replace pipeline import and usage with PrunaModel + const rewrittenSnippets = transformersSnippets.map(snippet => + snippet.replace(/from transformers import pipeline/g, "from pruna import PrunaModel") + ); + const prunaSnippets = rewrittenSnippets.map(snippet => + snippet.replace(/pipeline\([^\)]*\)/g, `PrunaModel.from_pretrained("${model.id}")`) + ); + + // If transformersInfo is not available, just return the basic replacements + if (!info) { + return prunaSnippets; + } + + // Further clean up the snippet to remove references to the original auto_model + const cleanedSnippets = prunaSnippets.map(snippet => { + let s = snippet + // Remove any import statements for the original auto_model + .replace(new RegExp(`from transformers import ${info.auto_model}\\n?`, "g"), "") + // Replace original from_pretrained calls with Pruna's + .replace(new RegExp(`${info.auto_model}.from_pretrained`, "g"), "PrunaModel.from_pretrained") + // Remove any extra auto_model arguments in function imports + // Only remove ", auto_model" if it's in a line with both "from" and "import" + .replace( + new RegExp( + `^.*from.*import.*(, *${info.auto_model})+.*$`, + "gm" + ), + line => line.replace(new RegExp(`, *${info.auto_model}`, "g"), "") + ); + return s; + }); + + // Add 'from pruna import PrunaModel' at the top if not present + const finalSnippets = cleanedSnippets.map(snippet => { + if (!/^from pruna import PrunaModel/m.test(snippet)) { + return `from pruna import PrunaModel\n${snippet}`; + } + return snippet; + }); + + return finalSnippets; +}; + +// Default Pruna snippet for unsupported or unknown libraries +export const pruna_default = (model: ModelData): string[] => [ + `from pruna import PrunaModel + +model = PrunaModel.from_pretrained("${model.id}") +`, +]; + + export const nemo = (model: ModelData): string[] => { let command: string[] | undefined = undefined; // Resolve the tag to a nemo domain/sub-domain @@ -1896,7 +2005,7 @@ export const nemo = (model: ModelData): string[] => { }; export const outetts = (model: ModelData): string[] => { - // Don’t show this block on GGUF / ONNX mirrors + // Don't show this block on GGUF / ONNX mirrors const t = model.tags ?? []; if (t.includes("gguf") || t.includes("onnx")) return []; @@ -1904,11 +2013,11 @@ export const outetts = (model: ModelData): string[] => { return [ ` import outetts - + enum = outetts.Models("${model.id}".split("/", 1)[1]) # VERSION_1_0_SIZE_1B cfg = outetts.ModelConfig.auto_config(enum, outetts.Backend.HF) tts = outetts.Interface(cfg) - + speaker = tts.load_default_speaker("EN-FEMALE-1-NEUTRAL") tts.generate( outetts.GenerationConfig( @@ -1943,7 +2052,7 @@ wav = model.generate(descriptions) # generates 3 samples.`, const magnet = (model: ModelData): string[] => [ `from audiocraft.models import MAGNeT - + model = MAGNeT.get_pretrained("${model.id}") descriptions = ['disco beat', 'energetic EDM', 'funky groove'] @@ -1952,7 +2061,7 @@ wav = model.generate(descriptions) # generates 3 samples.`, const audiogen = (model: ModelData): string[] => [ `from audiocraft.models import AudioGen - + model = AudioGen.get_pretrained("${model.id}") model.set_generation_params(duration=5) # generate 5 seconds. descriptions = ['dog barking', 'sirene of an emergency vehicle', 'footsteps in a corridor'] @@ -1985,7 +2094,7 @@ brew install whisperkit-cli # View all available inference options whisperkit-cli transcribe --help - + # Download and run inference using whisper base model whisperkit-cli transcribe --audio-path /path/to/audio.mp3 diff --git a/packages/tasks/src/model-libraries.ts b/packages/tasks/src/model-libraries.ts index d267b199dd..bdf55ccb95 100644 --- a/packages/tasks/src/model-libraries.ts +++ b/packages/tasks/src/model-libraries.ts @@ -776,6 +776,14 @@ export const MODEL_LIBRARIES_UI_ELEMENTS = { filter: false, countDownloads: `path_extension:"pth"`, }, + "pruna-ai": { + prettyLabel: "Pruna AI", + repoName: "Pruna AI", + repoUrl: "https://github.com/Pruna-AI/pruna-ai", + snippets: snippets.pruna, + docsUrl: "https://docs.pruna.ai", + filter: false, + }, pxia: { prettyLabel: "pxia", repoName: "pxia",