Skip to content

Commit 59fef0f

Browse files
authored
Merge pull request #108 from pluflou/add-gps
Add GPModel class
2 parents 1a13363 + 2aa60a6 commit 59fef0f

31 files changed

+4165
-1811
lines changed

environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name: lume-model
33
channels:
44
- conda-forge
55
dependencies:
6-
- python>=3.9
6+
- python>=3.10
77
- pydantic>2.3
88
- numpy
99
- pyyaml
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "235c92cd-cc05-42b8-a516-1185eeac5f0c",
6+
"metadata": {},
7+
"source": [
8+
"# Creating a Custom LUME-model for probabilistic models\n"
9+
]
10+
},
11+
{
12+
"cell_type": "code",
13+
"execution_count": 1,
14+
"id": "56725817-2b21-4bea-98b0-151dea959f77",
15+
"metadata": {},
16+
"outputs": [],
17+
"source": [
18+
"from torch.distributions.normal import Normal\n",
19+
"import torch\n",
20+
"from lume_model.models.prob_model_base import ProbModelBaseModel\n",
21+
"from lume_model.variables import ScalarVariable, DistributionVariable"
22+
]
23+
},
24+
{
25+
"cell_type": "markdown",
26+
"id": "79c62b18-7dc1-44ca-b578-4dea5cc4a4b4",
27+
"metadata": {},
28+
"source": [
29+
"## Create a model that returns a list of predictions"
30+
]
31+
},
32+
{
33+
"cell_type": "code",
34+
"execution_count": 2,
35+
"id": "f96d9863-269c-49d8-9671-cc73a783bcbc",
36+
"metadata": {},
37+
"outputs": [],
38+
"source": [
39+
"class ExampleModel(ProbModelBaseModel):\n",
40+
" \"\"\"Model returns a list of predictions for each output\"\"\"\n",
41+
"\n",
42+
" def _get_predictions(self, input_dict):\n",
43+
" \"\"\"\n",
44+
" This method implements the required abstract method for this class.\n",
45+
" It takes the input_dict and returns a dict of output names to distributions.\n",
46+
" \"\"\"\n",
47+
" # Just generate random output here for this example\n",
48+
" # but typically this is where you would adjust the input if needed and\n",
49+
" # call your model on the input\n",
50+
" output_dict = {\n",
51+
" \"output1\": torch.rand(5),\n",
52+
" \"output2\": torch.rand(10),\n",
53+
" }\n",
54+
" return self._create_output_dict(output_dict)\n",
55+
"\n",
56+
" def _create_output_dict(self, output):\n",
57+
" \"\"\"This method is not required by the abstract class but typically\n",
58+
" needed to create a distribution type output for each output\n",
59+
" name from the list of predictions that the model returns.\n",
60+
" \"\"\"\n",
61+
" output_dict = {}\n",
62+
" for k, v in output.items():\n",
63+
" output_dict[k] = Normal(v.mean(axis=0), torch.sqrt(v.var(axis=0)))\n",
64+
" return output_dict"
65+
]
66+
},
67+
{
68+
"cell_type": "markdown",
69+
"id": "868fff4d-1f46-48e2-8bd0-c9d831df79e6",
70+
"metadata": {},
71+
"source": [
72+
"### Model Instantiation and Execution\n",
73+
"Instantiation requires specification of the input and output variables of the model."
74+
]
75+
},
76+
{
77+
"cell_type": "code",
78+
"execution_count": 3,
79+
"id": "97946e64-062d-47d4-8d0c-d7e02a335a56",
80+
"metadata": {},
81+
"outputs": [],
82+
"source": [
83+
"input_variables = [\n",
84+
" ScalarVariable(name=\"input1\", default_value=0.1),\n",
85+
" ScalarVariable(name=\"input2\", default_value=0.2, value_range=[0.0, 1.0]),\n",
86+
"]\n",
87+
"output_variables = [\n",
88+
" DistributionVariable(name=\"output1\"),\n",
89+
" DistributionVariable(name=\"output2\"),\n",
90+
"]\n",
91+
"\n",
92+
"m = ExampleModel(input_variables=input_variables, output_variables=output_variables)"
93+
]
94+
},
95+
{
96+
"cell_type": "code",
97+
"execution_count": 4,
98+
"id": "50aae4be-0d6e-456f-83e8-3a84d6d78f84",
99+
"metadata": {},
100+
"outputs": [],
101+
"source": [
102+
"input_dict = {\n",
103+
" \"input1\": 0.3,\n",
104+
" \"input2\": 0.6,\n",
105+
"}\n",
106+
"out = m.evaluate(input_dict)"
107+
]
108+
},
109+
{
110+
"cell_type": "code",
111+
"execution_count": 5,
112+
"id": "9a74c70a-4d9a-443f-820d-69d111e574ed",
113+
"metadata": {},
114+
"outputs": [
115+
{
116+
"data": {
117+
"text/plain": [
118+
"{'output1': Normal(loc: 0.4858802855014801, scale: 0.3480694591999054),\n",
119+
" 'output2': Normal(loc: 0.5287243127822876, scale: 0.28792139887809753)}"
120+
]
121+
},
122+
"execution_count": 5,
123+
"metadata": {},
124+
"output_type": "execute_result"
125+
}
126+
],
127+
"source": [
128+
"out"
129+
]
130+
},
131+
{
132+
"cell_type": "code",
133+
"execution_count": 6,
134+
"id": "301aa223-f53f-498f-8b31-ed1008594f87",
135+
"metadata": {},
136+
"outputs": [
137+
{
138+
"data": {
139+
"text/plain": [
140+
"(tensor(0.4859), tensor(0.1212))"
141+
]
142+
},
143+
"execution_count": 6,
144+
"metadata": {},
145+
"output_type": "execute_result"
146+
}
147+
],
148+
"source": [
149+
"out[\"output1\"].mean, out[\"output1\"].variance"
150+
]
151+
},
152+
{
153+
"cell_type": "code",
154+
"execution_count": null,
155+
"id": "35ea12da-d0c6-49cc-8a00-2b096fc7248b",
156+
"metadata": {},
157+
"outputs": [],
158+
"source": []
159+
},
160+
{
161+
"cell_type": "code",
162+
"execution_count": null,
163+
"id": "63d56fd9-8c25-4371-921b-968773376203",
164+
"metadata": {},
165+
"outputs": [],
166+
"source": []
167+
}
168+
],
169+
"metadata": {
170+
"kernelspec": {
171+
"display_name": "Python 3 (ipykernel)",
172+
"language": "python",
173+
"name": "python3"
174+
},
175+
"language_info": {
176+
"codemirror_mode": {
177+
"name": "ipython",
178+
"version": 3
179+
},
180+
"file_extension": ".py",
181+
"mimetype": "text/x-python",
182+
"name": "python",
183+
"nbconvert_exporter": "python",
184+
"pygments_lexer": "ipython3",
185+
"version": "3.10.16"
186+
}
187+
},
188+
"nbformat": 4,
189+
"nbformat_minor": 5
190+
}

0 commit comments

Comments
 (0)