|
44 | 44 | ) |
45 | 45 | from pyspark.sql.functions import udf |
46 | 46 | from typing import cast, Optional, TypeVar, Type |
| 47 | +from synapse.ml.core.platform import running_on_synapse_internal |
47 | 48 |
|
48 | 49 | OPENAI_API_VERSION = "2022-12-01" |
49 | 50 | RL = TypeVar("RL", bound="MLReadable") |
@@ -125,6 +126,14 @@ def __init__( |
125 | 126 | self.subscriptionKey = Param(self, "subscriptionKey", "openai api key") |
126 | 127 | self.url = Param(self, "url", "openai api base") |
127 | 128 | self.apiVersion = Param(self, "apiVersion", "openai api version") |
| 129 | + self.running_on_synapse_internal = running_on_synapse_internal() |
| 130 | + if running_on_synapse_internal(): |
| 131 | + from synapse.ml.fabric.service_discovery import get_fabric_env_config |
| 132 | + |
| 133 | + self._setDefault( |
| 134 | + url=get_fabric_env_config().fabric_env_config.ml_workload_endpoint |
| 135 | + + "cognitive/openai" |
| 136 | + ) |
128 | 137 | kwargs = self._input_kwargs |
129 | 138 | if subscriptionKey: |
130 | 139 | kwargs["subscriptionKey"] = subscriptionKey |
@@ -196,10 +205,15 @@ def _transform(self, dataset): |
196 | 205 | def udfFunction(x): |
197 | 206 | import openai |
198 | 207 |
|
199 | | - openai.api_type = "azure" |
200 | | - openai.api_key = self.getSubscriptionKey() |
201 | | - openai.api_base = self.getUrl() |
202 | | - openai.api_version = self.getApiVersion() |
| 208 | + if self.running_on_synapse_internal and not self.isSet(self.url): |
| 209 | + from synapse.ml.fabric.prerun.openai_prerun import OpenAIPrerun |
| 210 | + |
| 211 | + OpenAIPrerun(api_base=self.getUrl()).init_personalized_session(None) |
| 212 | + else: |
| 213 | + openai.api_type = "azure" |
| 214 | + openai.api_key = self.getSubscriptionKey() |
| 215 | + openai.api_base = self.getUrl() |
| 216 | + openai.api_version = self.getApiVersion() |
203 | 217 | return self.getChain().run(x) |
204 | 218 |
|
205 | 219 | outCol = self.getOutputCol() |
|
0 commit comments