@@ -132,6 +132,62 @@ wav = model.generate(text, audio_prompt_path=AUDIO_PROMPT_PATH)
132132ta.save("test-2.wav", wav, model.sr)` ,
133133] ;
134134
135+ export const contexttab = ( ) : string [ ] => {
136+ const installSnippet = `pip install git+https://github.com/SAP-samples/contexttab` ;
137+
138+ const classificationSnippet = `# Run a classification task
139+ from sklearn.datasets import load_breast_cancer
140+ from sklearn.metrics import accuracy_score
141+ from sklearn.model_selection import train_test_split
142+
143+ from contexttab import ConTextTabClassifier
144+
145+ # Load sample data
146+ X, y = load_breast_cancer(return_X_y=True)
147+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)
148+
149+ # Initialize a classifier
150+ # You can omit checkpoint and checkpoint_revision to use the default model
151+ clf = ConTextTabClassifier(checkpoint="l2/base.pt", checkpoint_revision="v1.0.0", bagging=1, max_context_size=2048)
152+
153+ clf.fit(X_train, y_train)
154+
155+ # Predict probabilities
156+ prediction_probabilities = clf.predict_proba(X_test)
157+ # Predict labels
158+ predictions = clf.predict(X_test)
159+ print("Accuracy", accuracy_score(y_test, predictions))` ;
160+
161+ const regressionsSnippet = `# Run a regression task
162+ from sklearn.datasets import fetch_openml
163+ from sklearn.metrics import r2_score
164+ from sklearn.model_selection import train_test_split
165+
166+ from contexttab import ConTextTabRegressor
167+
168+
169+ # Load sample data
170+ df = fetch_openml(data_id=531, as_frame=True)
171+ X = df.data
172+ y = df.target.astype(float)
173+
174+ # Train-test split
175+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)
176+
177+ # Initialize the regressor
178+ # You can omit checkpoint and checkpoint_revision to use the default model
179+ regressor = ConTextTabRegressor(checkpoint="l2/base.pt", checkpoint_revision="v1.0.0", bagging=1, max_context_size=2048)
180+
181+ regressor.fit(X_train, y_train)
182+
183+ # Predict on the test set
184+ predictions = regressor.predict(X_test)
185+
186+ r2 = r2_score(y_test, predictions)
187+ print("R² Score:", r2)` ;
188+ return [ installSnippet , classificationSnippet , regressionsSnippet ] ;
189+ } ;
190+
135191export const cxr_foundation = ( ) : string [ ] => [
136192 `# pip install git+https://github.com/Google-Health/cxr-foundation.git#subdirectory=python
137193
0 commit comments