|
17 | 17 |
|
18 | 18 | from ragas.callbacks import new_group |
19 | 19 | from ragas.dataset_schema import MultiTurnSample, SingleTurnSample |
| 20 | +from ragas.executor import is_event_loop_running |
20 | 21 | from ragas.run_config import RunConfig |
21 | 22 | from ragas.utils import deprecated |
22 | 23 |
|
@@ -59,8 +60,7 @@ class Metric(ABC): |
59 | 60 |
|
60 | 61 | @property |
61 | 62 | @abstractmethod |
62 | | - def name(self) -> str: |
63 | | - ... |
| 63 | + def name(self) -> str: ... |
64 | 64 |
|
65 | 65 | @property |
66 | 66 | def required_columns(self) -> t.Dict[str, t.Set[str]]: |
@@ -103,6 +103,15 @@ def score(self: t.Self, row: t.Dict, callbacks: Callbacks = None) -> float: |
103 | 103 | callbacks = callbacks or [] |
104 | 104 | rm, group_cm = new_group(self.name, inputs=row, callbacks=callbacks) |
105 | 105 | try: |
| 106 | + if is_event_loop_running(): |
| 107 | + try: |
| 108 | + import nest_asyncio |
| 109 | + |
| 110 | + nest_asyncio.apply() |
| 111 | + except ImportError: |
| 112 | + raise ImportError( |
| 113 | + "It seems like your running this in a jupyter-like environment. Please install nest_asyncio with `pip install nest_asyncio` to make it work." |
| 114 | + ) |
106 | 115 | loop = asyncio.get_event_loop() |
107 | 116 | score = loop.run_until_complete(self._ascore(row=row, callbacks=group_cm)) |
108 | 117 | except Exception as e: |
@@ -138,8 +147,7 @@ async def ascore( |
138 | 147 | return score |
139 | 148 |
|
140 | 149 | @abstractmethod |
141 | | - async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float: |
142 | | - ... |
| 150 | + async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float: ... |
143 | 151 |
|
144 | 152 |
|
145 | 153 | @dataclass |
@@ -246,8 +254,7 @@ async def _single_turn_ascore( |
246 | 254 | self, |
247 | 255 | sample: SingleTurnSample, |
248 | 256 | callbacks: Callbacks, |
249 | | - ) -> float: |
250 | | - ... |
| 257 | + ) -> float: ... |
251 | 258 |
|
252 | 259 |
|
253 | 260 | class MultiTurnMetric(Metric): |
@@ -299,8 +306,7 @@ async def _multi_turn_ascore( |
299 | 306 | self, |
300 | 307 | sample: MultiTurnSample, |
301 | 308 | callbacks: Callbacks, |
302 | | - ) -> float: |
303 | | - ... |
| 309 | + ) -> float: ... |
304 | 310 |
|
305 | 311 |
|
306 | 312 | class Ensember: |
|
0 commit comments