Skip to content

Commit 2e4d048

Browse files
Added new file format param to skops snippet script (#417)
Co-authored-by: Omar Sanseviero <[email protected]>
1 parent 55c6353 commit 2e4d048

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

js/src/lib/interfaces/Libraries.ts

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,9 @@ const paddlenlp = (model: ModelData) => {
161161
return [
162162
`from paddlenlp.transformers import AutoModel, AutoTokenizer`,
163163
"",
164-
`tokenizer = AutoTokenizer.from_pretrained("${model.id}"${model.private ? ", use_auth_token=True" : ""}, from_hf_hub=True)`,
164+
`tokenizer = AutoTokenizer.from_pretrained("${model.id}"${model.private ? ", use_auth_token=True" : ""}, from_hf_hub=True)`,
165165
`model = AutoModel.from_pretrained("${model.id}"${model.private ? ", use_auth_token=True" : ""}, from_hf_hub=True)`,
166-
].join("\n");
166+
].join("\n");
167167
};
168168

169169
const pyannote_audio_pipeline = (model: ModelData) =>
@@ -240,20 +240,32 @@ model = timm.create_model("hf_hub:${model.id}", pretrained=True)`;
240240
const sklearn = (model: ModelData) => {
241241
if (model.tags?.includes("skops")) {
242242
const skopsmodelFile = model.config?.sklearn?.filename;
243-
return `from skops.hub_utils import download
243+
const skopssaveFormat = model.config?.sklearn?.model_format;
244+
if (skopssaveFormat === "pickle") {
245+
return `import joblib
246+
from skops.hub_utils import download
247+
download("${model.id}", "path_to_folder")
248+
model = joblib.load(
249+
"${skopsmodelFile}"
250+
)
251+
# only load pickle files from sources you trust
252+
# read more about it here https://skops.readthedocs.io/en/stable/persistence.html`;
253+
} else {
254+
return `from skops.hub_utils import download
244255
from skops.io import load
245-
246256
download("${model.id}", "path_to_folder")
247257
# make sure model file is in skops format
248258
# if model is a pickle file, make sure it's from a source you trust
249259
model = load("path_to_folder/${skopsmodelFile}")`;
260+
}
250261
} else {
251262
return `from huggingface_hub import hf_hub_download
252263
import joblib
253-
254264
model = joblib.load(
255265
hf_hub_download("${model.id}", "sklearn_model.joblib")
256-
)`;
266+
)
267+
# only load pickle files from sources you trust
268+
# read more about it here https://skops.readthedocs.io/en/stable/persistence.html`;
257269
}
258270
};
259271

0 commit comments

Comments
 (0)