Skip to content

Commit edef2d3

Browse files
committed
Add ConTextTab model
1 parent d5e865f commit edef2d3

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

packages/tasks/src/model-libraries-snippets.ts

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,60 @@ wav = model.generate(text, audio_prompt_path=AUDIO_PROMPT_PATH)
132132
ta.save("test-2.wav", wav, model.sr)`,
133133
];
134134

135+
export const contexttab = (): string[] => [
136+
`# pip install git+https://github.com/SAP-samples/contexttab
137+
138+
# Run a classification task
139+
from sklearn.datasets import load_breast_cancer
140+
from sklearn.metrics import accuracy_score
141+
from sklearn.model_selection import train_test_split
142+
143+
from contexttab import ConTextTabClassifier
144+
145+
# Load sample data
146+
X, y = load_breast_cancer(return_X_y=True)
147+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)
148+
149+
# Initialize a classifier
150+
clf = ConTextTabClassifier(bagging=1, max_context_size=2048)
151+
152+
clf.fit(X_train, y_train)
153+
154+
# Predict probabilities
155+
prediction_probabilities = clf.predict_proba(X_test)
156+
# Predict labels
157+
predictions = clf.predict(X_test)
158+
print("Accuracy", accuracy_score(y_test, predictions))
159+
160+
# Run a regression task
161+
from sklearn.datasets import fetch_openml
162+
from sklearn.metrics import r2_score
163+
from sklearn.model_selection import train_test_split
164+
165+
from contexttab import ConTextTabRegressor
166+
167+
168+
# Load sample data
169+
df = fetch_openml(data_id=531, as_frame=True)
170+
X = df.data
171+
y = df.target.astype(float)
172+
173+
# Train-test split
174+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)
175+
176+
# Initialize the regressor
177+
regressor = ConTextTabRegressor(bagging=1, max_context_size=2048)
178+
179+
regressor.fit(X_train, y_train)
180+
181+
# Predict on the test set
182+
predictions = regressor.predict(X_test)
183+
184+
r2 = r2_score(y_test, predictions)
185+
print("R² Score:", r2)`,
186+
];
187+
188+
135189
export const cxr_foundation = (): string[] => [
136190
`# pip install git+https://github.com/Google-Health/cxr-foundation.git#subdirectory=python
137191

packages/tasks/src/model-libraries.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,13 @@ export const MODEL_LIBRARIES_UI_ELEMENTS = {
208208
repoUrl: "https://github.com/Unbabel/COMET/",
209209
countDownloads: `path:"hparams.yaml"`,
210210
},
211+
contexttab: {
212+
prettyLabel: "ConTextTab",
213+
repoName: "contexttab",
214+
repoUrl: "https://github.com/SAP-samples/contexttab",
215+
countDownloads: `path_extension:"pt"`,
216+
snippets: snippets.contexttab,
217+
},
211218
cosmos: {
212219
prettyLabel: "Cosmos",
213220
repoName: "Cosmos",

0 commit comments

Comments
 (0)