Skip to content

Commit 8f794c8

Browse files
lhrotkmslhrotk
andauthored
feat: Support langchain transformer on fabric (microsoft#2036)
* support langchain transformer on fabric * avoid addtional param * format code --------- Co-authored-by: cruise <cruiseli@microsoft.com>
1 parent 149c634 commit 8f794c8

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

cognitive/src/main/python/synapse/ml/cognitive/langchain/LangchainTransform.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
)
4545
from pyspark.sql.functions import udf
4646
from typing import cast, Optional, TypeVar, Type
47+
from synapse.ml.core.platform import running_on_synapse_internal
4748

4849
OPENAI_API_VERSION = "2022-12-01"
4950
RL = TypeVar("RL", bound="MLReadable")
@@ -125,6 +126,14 @@ def __init__(
125126
self.subscriptionKey = Param(self, "subscriptionKey", "openai api key")
126127
self.url = Param(self, "url", "openai api base")
127128
self.apiVersion = Param(self, "apiVersion", "openai api version")
129+
self.running_on_synapse_internal = running_on_synapse_internal()
130+
if running_on_synapse_internal():
131+
from synapse.ml.fabric.service_discovery import get_fabric_env_config
132+
133+
self._setDefault(
134+
url=get_fabric_env_config().fabric_env_config.ml_workload_endpoint
135+
+ "cognitive/openai"
136+
)
128137
kwargs = self._input_kwargs
129138
if subscriptionKey:
130139
kwargs["subscriptionKey"] = subscriptionKey
@@ -196,10 +205,15 @@ def _transform(self, dataset):
196205
def udfFunction(x):
197206
import openai
198207

199-
openai.api_type = "azure"
200-
openai.api_key = self.getSubscriptionKey()
201-
openai.api_base = self.getUrl()
202-
openai.api_version = self.getApiVersion()
208+
if self.running_on_synapse_internal and not self.isSet(self.url):
209+
from synapse.ml.fabric.prerun.openai_prerun import OpenAIPrerun
210+
211+
OpenAIPrerun(api_base=self.getUrl()).init_personalized_session(None)
212+
else:
213+
openai.api_type = "azure"
214+
openai.api_key = self.getSubscriptionKey()
215+
openai.api_base = self.getUrl()
216+
openai.api_version = self.getApiVersion()
203217
return self.getChain().run(x)
204218

205219
outCol = self.getOutputCol()

0 commit comments

Comments
 (0)