Skip to content

Add Pruna AI library snippets and integration #1684

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 126 additions & 17 deletions packages/tasks/src/model-libraries-snippets.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,15 @@ from audioseal import AudioSeal
model = AudioSeal.load_generator("${model.id}")
# pass a tensor (tensor_wav) of shape (batch, channels, samples) and a sample rate
wav, sr = tensor_wav, 16000

watermark = model.get_watermark(wav, sr)
watermarked_audio = wav + watermark`;

const detectorSnippet = `# Watermark Detector
from audioseal import AudioSeal

detector = AudioSeal.load_detector("${model.id}")

result, message = detector.detect_watermark(watermarked_audio, sr)`;
return [watermarkSnippet, detectorSnippet];
};
Expand Down Expand Up @@ -580,7 +580,7 @@ export const cartesia_mlx = (model: ModelData): string[] => [
import cartesia_mlx as cmx

model = cmx.from_pretrained("${model.id}")
model.set_dtype(mx.float32)
model.set_dtype(mx.float32)

prompt = "Rene Descartes was"

Expand Down Expand Up @@ -703,7 +703,7 @@ export const keras = (model: ModelData): string[] => [
`# Available backend options are: "jax", "torch", "tensorflow".
import os
os.environ["KERAS_BACKEND"] = "jax"

import keras

model = keras.saving.load_model("hf://${model.id}")
Expand Down Expand Up @@ -867,7 +867,7 @@ model.score("query", ["doc1", "doc2", "doc3"])`,
from lightning_ir import BiEncoderModule, CrossEncoderModule

# depending on the model type, use either BiEncoderModule or CrossEncoderModule
model = BiEncoderModule("${model.id}")
model = BiEncoderModule("${model.id}")
# model = CrossEncoderModule("${model.id}")

model.score("query", ["doc1", "doc2", "doc3"])`,
Expand Down Expand Up @@ -916,7 +916,7 @@ pip install -e .[smolvla]`,
`# Launch finetuning on your dataset
python lerobot/scripts/train.py \\
--policy.path=${model.id} \\
--dataset.repo_id=lerobot/svla_so101_pickplace \\
--dataset.repo_id=lerobot/svla_so101_pickplace \\
--batch_size=64 \\
--steps=20000 \\
--output_dir=outputs/train/my_smolvla \\
Expand All @@ -927,7 +927,7 @@ python lerobot/scripts/train.py \\
if (model.id !== "lerobot/smolvla_base") {
// Inference snippet (only if not base model)
smolvlaSnippets.push(
`# Run the policy using the record function
`# Run the policy using the record function
python -m lerobot.record \\
--robot.type=so101_follower \\
--robot.port=/dev/ttyACM0 \\ # <- Use your port
Expand Down Expand Up @@ -1067,7 +1067,7 @@ for res in output:
}

return [
`# Please refer to the document for information on how to use the model.
`# Please refer to the document for information on how to use the model.
# https://paddlepaddle.github.io/PaddleOCR/latest/en/version3.x/module_usage/module_overview.html`,
];
};
Expand Down Expand Up @@ -1103,7 +1103,7 @@ wan_i2v = WanI2V(

export const pyannote_audio_pipeline = (model: ModelData): string[] => [
`from pyannote.audio import Pipeline

pipeline = Pipeline.from_pretrained("${model.id}")

# inference on the whole file
Expand Down Expand Up @@ -1142,7 +1142,7 @@ export const pyannote_audio = (model: ModelData): string[] => {

export const relik = (model: ModelData): string[] => [
`from relik import Relik

relik = Relik.from_pretrained("${model.id}")`,
];

Expand Down Expand Up @@ -1326,7 +1326,7 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
const video_predictor = `# Use SAM2 with videos
import torch
from sam2.sam2_video_predictor import SAM2VideoPredictor

predictor = SAM2VideoPredictor.from_pretrained(${model.id})

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
Expand Down Expand Up @@ -1885,6 +1885,115 @@ export const model2vec = (model: ModelData): string[] => [
model = StaticModel.from_pretrained("${model.id}")`,
];

// Main entry point for generating Pruna code snippets.
// Dispatches on the model's declared library to produce the most
// specific snippet available, falling back to a generic one.
export const pruna = (model: ModelData): string[] => {
	switch (model.library_name) {
		case "diffusers":
			// diffusers models get pipeline-rewritten snippets
			return pruna_diffusers(model);
		case "transformers":
			// transformers models get pipeline/auto-model rewritten snippets
			return pruna_transformers(model);
		default:
			// any other (or unknown) library: generic PrunaModel loader
			return pruna_default(model);
	}
};

// Generates Pruna code snippets for diffusers models.
// Strategy: take the library's standard diffusers snippets and textually rewrite
// them so every pipeline class becomes PrunaModel, imported from `pruna`.
export const pruna_diffusers = (model: ModelData): string[] => {
const diffusersSnippets = diffusers(model);

// Replace all pipeline class names with PrunaModel
// NOTE(review): the replace chain below is order-sensitive — each step runs on the
// output of the previous one, so reordering steps changes the result.
const rewrittenSnippets = diffusersSnippets.map(snippet =>
snippet
// First, replace ALL pipeline class references (including in imports)
// Handle classes that end with "Pipeline" (like DiffusionPipeline)
.replace(/\b(\w+Pipeline)\b/g, "PrunaModel")
// Handle classes that contain "Pipeline" but don't end with it (like AutoPipelineForInpainting)
.replace(/\b(\w+Pipeline\w+)\b/g, "PrunaModel")
// Then handle imports - remove pipeline imports completely
// NOTE(review): this step deletes `from diffusers import PrunaModel` lines outright,
// so the later `from diffusers import PrunaModel` -> `from pruna import PrunaModel`
// rewrite below may never match — confirm whether both paths are needed.
.replace(/from diffusers import ([^,\n]*PrunaModel[^,\n]*)/g, "")
.replace(/from diffusers import ([^,\n]+),?\s*([^,\n]*PrunaModel[^,\n]*)/g, "from diffusers import $1")
// Clean up any empty import lines
.replace(/from diffusers import\s*\n/g, "")
.replace(/from diffusers import\s*$/g, "")
// Clean up any double newlines
// NOTE(review): this collapses EVERY blank line in the snippet, not just ones
// left behind by import removal — confirm the formatting loss is intended.
.replace(/\n\n+/g, "\n")
// Fix the PrunaModel import issue - replace any remaining "from diffusers import PrunaModel"
.replace(/from diffusers import PrunaModel/g, "from pruna import PrunaModel")
// Handle cases where PrunaModel is imported alongside other classes
.replace(/from diffusers import ([^,\n]+), PrunaModel/g, "from diffusers import $1")
.replace(/from diffusers import PrunaModel, ([^,\n]+)/g, "from diffusers import $1")
.trim()
);

// Always add the PrunaModel import at the beginning of each snippet
// (multiline regex: matches only when a line starts with the import).
const finalSnippets = rewrittenSnippets.map(snippet => {
if (!/^from pruna import PrunaModel/m.test(snippet)) {
return `from pruna import PrunaModel\n${snippet}`;
}
return snippet;
});

return finalSnippets;
};

// Generates Pruna code snippets for transformers models.
// Strategy: take the library's standard transformers snippets and textually rewrite
// pipeline/auto-model usage into PrunaModel.from_pretrained calls.
export const pruna_transformers = (model: ModelData): string[] => {
const info = model.transformersInfo;
const transformersSnippets = transformers(model);

// Replace pipeline import and usage with PrunaModel
const rewrittenSnippets = transformersSnippets.map(snippet =>
snippet.replace(/from transformers import pipeline/g, "from pruna import PrunaModel")
);
// NOTE(review): `[^\)]*` assumes pipeline(...) calls contain no nested parentheses
// in their arguments — confirm against the snippets transformers() produces.
const prunaSnippets = rewrittenSnippets.map(snippet =>
snippet.replace(/pipeline\([^\)]*\)/g, `PrunaModel.from_pretrained("${model.id}")`)
);

// If transformersInfo is not available, just return the basic replacements
if (!info) {
return prunaSnippets;
}

// Further clean up the snippet to remove references to the original auto_model
// NOTE(review): info.auto_model is interpolated into RegExps unescaped; this is
// safe only while auto_model names contain no regex metacharacters — confirm.
const cleanedSnippets = prunaSnippets.map(snippet => {
let s = snippet
// Remove any import statements for the original auto_model
.replace(new RegExp(`from transformers import ${info.auto_model}\\n?`, "g"), "")
// Replace original from_pretrained calls with Pruna's
.replace(new RegExp(`${info.auto_model}.from_pretrained`, "g"), "PrunaModel.from_pretrained")
// Remove any extra auto_model arguments in function imports
// Only remove ", auto_model" if it's in a line with both "from" and "import"
.replace(
new RegExp(
`^.*from.*import.*(, *${info.auto_model})+.*$`,
"gm"
),
// per-line callback: strip only the ", AutoModel..." segments from the matched import line
line => line.replace(new RegExp(`, *${info.auto_model}`, "g"), "")
);
return s;
});

// Add 'from pruna import PrunaModel' at the top if not present
const finalSnippets = cleanedSnippets.map(snippet => {
if (!/^from pruna import PrunaModel/m.test(snippet)) {
return `from pruna import PrunaModel\n${snippet}`;
}
return snippet;
});

return finalSnippets;
};

// Default Pruna snippet for unsupported or unknown libraries:
// a minimal, always-valid PrunaModel loading example.
export const pruna_default = (model: ModelData): string[] => {
	const lines = [
		"from pruna import PrunaModel",
		"",
		`model = PrunaModel.from_pretrained("${model.id}")`,
		"", // trailing newline, matching the original template literal
	];
	return [lines.join("\n")];
};


export const nemo = (model: ModelData): string[] => {
let command: string[] | undefined = undefined;
// Resolve the tag to a nemo domain/sub-domain
Expand All @@ -1896,19 +2005,19 @@ export const nemo = (model: ModelData): string[] => {
};

export const outetts = (model: ModelData): string[] => {
// Dont show this block on GGUF / ONNX mirrors
// Don't show this block on GGUF / ONNX mirrors
const t = model.tags ?? [];
if (t.includes("gguf") || t.includes("onnx")) return [];

// v1.0 HF → minimal runnable snippet
return [
`
import outetts

enum = outetts.Models("${model.id}".split("/", 1)[1]) # VERSION_1_0_SIZE_1B
cfg = outetts.ModelConfig.auto_config(enum, outetts.Backend.HF)
tts = outetts.Interface(cfg)

speaker = tts.load_default_speaker("EN-FEMALE-1-NEUTRAL")
tts.generate(
outetts.GenerationConfig(
Expand Down Expand Up @@ -1943,7 +2052,7 @@ wav = model.generate(descriptions) # generates 3 samples.`,

const magnet = (model: ModelData): string[] => [
`from audiocraft.models import MAGNeT

model = MAGNeT.get_pretrained("${model.id}")

descriptions = ['disco beat', 'energetic EDM', 'funky groove']
Expand All @@ -1952,7 +2061,7 @@ wav = model.generate(descriptions) # generates 3 samples.`,

const audiogen = (model: ModelData): string[] => [
`from audiocraft.models import AudioGen

model = AudioGen.get_pretrained("${model.id}")
model.set_generation_params(duration=5) # generate 5 seconds.
descriptions = ['dog barking', 'sirene of an emergency vehicle', 'footsteps in a corridor']
Expand Down Expand Up @@ -1985,7 +2094,7 @@ brew install whisperkit-cli

# View all available inference options
whisperkit-cli transcribe --help

# Download and run inference using whisper base model
whisperkit-cli transcribe --audio-path /path/to/audio.mp3

Expand Down
8 changes: 8 additions & 0 deletions packages/tasks/src/model-libraries.ts
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,14 @@ export const MODEL_LIBRARIES_UI_ELEMENTS = {
filter: false,
countDownloads: `path_extension:"pth"`,
},
"pruna-ai": {
prettyLabel: "Pruna AI",
repoName: "Pruna AI",
repoUrl: "https://github.com/Pruna-AI/pruna-ai",
snippets: snippets.pruna,
docsUrl: "https://docs.pruna.ai",
filter: false,
},
pxia: {
prettyLabel: "pxia",
repoName: "pxia",
Expand Down
Loading