diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts index 08921cf533..212686160e 100644 --- a/packages/tasks/src/model-libraries-snippets.ts +++ b/packages/tasks/src/model-libraries-snippets.ts @@ -132,6 +132,62 @@ wav = model.generate(text, audio_prompt_path=AUDIO_PROMPT_PATH) ta.save("test-2.wav", wav, model.sr)`, ]; +export const contexttab = (): string[] => { + const installSnippet = `pip install git+https://github.com/SAP-samples/contexttab`; + + const classificationSnippet = `# Run a classification task +from sklearn.datasets import load_breast_cancer +from sklearn.metrics import accuracy_score +from sklearn.model_selection import train_test_split + +from contexttab import ConTextTabClassifier + +# Load sample data +X, y = load_breast_cancer(return_X_y=True) +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42) + +# Initialize a classifier +# You can omit checkpoint and checkpoint_revision to use the default model +clf = ConTextTabClassifier(checkpoint="l2/base.pt", checkpoint_revision="v1.0.0", bagging=1, max_context_size=2048) + +clf.fit(X_train, y_train) + +# Predict probabilities +prediction_probabilities = clf.predict_proba(X_test) +# Predict labels +predictions = clf.predict(X_test) +print("Accuracy", accuracy_score(y_test, predictions))`; + + const regressionsSnippet = `# Run a regression task +from sklearn.datasets import fetch_openml +from sklearn.metrics import r2_score +from sklearn.model_selection import train_test_split + +from contexttab import ConTextTabRegressor + + +# Load sample data +df = fetch_openml(data_id=531, as_frame=True) +X = df.data +y = df.target.astype(float) + +# Train-test split +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42) + +# Initialize the regressor +# You can omit checkpoint and checkpoint_revision to use the default model +regressor = ConTextTabRegressor(checkpoint="l2/base.pt", checkpoint_revision="v1.0.0", bagging=1, max_context_size=2048) + +regressor.fit(X_train, y_train) + +# Predict on the test set +predictions = regressor.predict(X_test) + +r2 = r2_score(y_test, predictions) +print("R² Score:", r2)`; + return [installSnippet, classificationSnippet, regressionsSnippet]; +}; + export const cxr_foundation = (): string[] => [ `# pip install git+https://github.com/Google-Health/cxr-foundation.git#subdirectory=python diff --git a/packages/tasks/src/model-libraries.ts b/packages/tasks/src/model-libraries.ts index b51e19f3d6..9a961b8955 100644 --- a/packages/tasks/src/model-libraries.ts +++ b/packages/tasks/src/model-libraries.ts @@ -208,6 +208,13 @@ export const MODEL_LIBRARIES_UI_ELEMENTS = { repoUrl: "https://github.com/Unbabel/COMET/", countDownloads: `path:"hparams.yaml"`, }, + contexttab: { + prettyLabel: "ConTextTab", + repoName: "ConTextTab", + repoUrl: "https://github.com/SAP-samples/contexttab", + countDownloads: `path_extension:"pt"`, + snippets: snippets.contexttab, + }, cosmos: { prettyLabel: "Cosmos", repoName: "Cosmos",