Skip to content

Commit 5e5433d

Browse files
authored
Add ruff rules for return (#571)
1 parent c404853 commit 5e5433d

File tree

23 files changed

+83
-117
lines changed

23 files changed

+83
-117
lines changed

examples/evaluation/tru_shared.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ def get_test_data():
6363
def init_tru():
6464
if os.getenv("TRULENS_DB_CONN_STRING"):
6565
return Tru(database_url=os.getenv("TRULENS_DB_CONN_STRING"))
66-
else:
67-
return Tru()
66+
return Tru()
6867

6968

7069
def get_feedback_functions(pipeline, golden_set):
@@ -123,15 +122,14 @@ def get_recorder(
123122
feedbacks=feedbacks,
124123
feedback_mode=feedback_mode,
125124
)
126-
elif framework == Framework.LLAMA_INDEX:
125+
if framework == Framework.LLAMA_INDEX:
127126
return TruLlama(
128127
pipeline,
129128
app_id=app_id,
130129
feedbacks=feedbacks,
131130
feedback_mode=feedback_mode,
132131
)
133-
else:
134-
raise ValueError(f"Unknown framework: {framework} specified for get_recorder()")
132+
raise ValueError(f"Unknown framework: {framework} specified for get_recorder()")
135133

136134

137135
def get_azure_chat_model(
@@ -144,7 +142,7 @@ def get_azure_chat_model(
144142
model_version=model_version,
145143
temperature=temperature,
146144
)
147-
elif framework == Framework.LLAMA_INDEX:
145+
if framework == Framework.LLAMA_INDEX:
148146
return LlamaAzureChatOpenAI(
149147
deployment_name=deployment_name,
150148
model=deployment_name,
@@ -153,27 +151,25 @@ def get_azure_chat_model(
153151
model_version=model_version,
154152
temperature=temperature,
155153
)
156-
else:
157-
raise ValueError(f"Unknown framework: {framework} specified for getChatModel()")
154+
raise ValueError(f"Unknown framework: {framework} specified for getChatModel()")
158155

159156

160157
def get_azure_embeddings_model(framework: Framework):
161158
if framework == Framework.LANG_CHAIN:
162159
return AzureOpenAIEmbeddings(
163160
azure_deployment="text-embedding-ada-002", openai_api_version="2023-05-15"
164161
)
165-
elif framework == Framework.LLAMA_INDEX:
162+
if framework == Framework.LLAMA_INDEX:
166163
return AzureOpenAIEmbedding(
167164
deployment_name="text-embedding-ada-002",
168165
model="text-embedding-ada-002",
169166
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
170167
api_version="2023-05-15",
171168
temperature=temperature,
172169
)
173-
else:
174-
raise ValueError(
175-
f"Unknown framework: {framework} specified for getEmbeddingsModel()"
176-
)
170+
raise ValueError(
171+
f"Unknown framework: {framework} specified for getEmbeddingsModel()"
172+
)
177173

178174

179175
def get_astra_vector_store(framework: Framework, collection_name: str):
@@ -184,17 +180,16 @@ def get_astra_vector_store(framework: Framework, collection_name: str):
184180
token=os.getenv("ASTRA_DB_APPLICATION_TOKEN"),
185181
api_endpoint=os.getenv("ASTRA_DB_API_ENDPOINT"),
186182
)
187-
elif framework == Framework.LLAMA_INDEX:
183+
if framework == Framework.LLAMA_INDEX:
188184
return AstraDBVectorStore(
189185
collection_name=collection_name,
190186
api_endpoint=os.getenv("ASTRA_DB_API_ENDPOINT"),
191187
token=os.getenv("ASTRA_DB_APPLICATION_TOKEN"),
192188
embedding_dimension=1536,
193189
)
194-
else:
195-
raise ValueError(
196-
f"Unknown framework: {framework} specified for get_astra_vector_store()"
197-
)
190+
raise ValueError(
191+
f"Unknown framework: {framework} specified for get_astra_vector_store()"
192+
)
198193

199194

200195
def execute_query(framework: Framework, pipeline, query):

examples/notebooks/langchain_evaluation.ipynb

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -475,12 +475,11 @@
475475
"# create a constructor to pass in to the run_on_dataset method.\n",
476476
"# This is so any state in the chain is not reused when evaluating individual examples.\n",
477477
"def create_qa_chain(llm, vstore, return_context=True):\n",
478-
" qa_chain = RetrievalQA.from_chain_type(\n",
478+
" return RetrievalQA.from_chain_type(\n",
479479
" llm,\n",
480480
" retriever=vstore.as_retriever(),\n",
481481
" return_source_documents=return_context,\n",
482-
" )\n",
483-
" return qa_chain"
482+
" )"
484483
]
485484
},
486485
{

libs/e2e-tests/e2e_tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,9 @@ def get_vector_store_handler(
6464
) -> VectorStoreHandler:
6565
if vector_database_type == "astradb":
6666
return AstraDBVectorStoreHandler(implementation)
67-
elif vector_database_type == "local-cassandra":
67+
if vector_database_type == "local-cassandra":
6868
return CassandraVectorStoreHandler(implementation)
69+
raise ValueError("Invalid vector store implementation")
6970

7071

7172
failed_report_lines = []

libs/e2e-tests/e2e_tests/langchain/trulens.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,12 @@ def _initialize_tru() -> Tru:
4545

4646
def _create_chain(retriever: VectorStoreRetriever, llm: BaseLanguageModel) -> Runnable:
4747
prompt = PromptTemplate.from_template(BASIC_QA_PROMPT)
48-
chain = (
48+
return (
4949
{"context": retriever | format_docs, "question": RunnablePassthrough()}
5050
| prompt
5151
| llm
5252
| StrOutputParser()
5353
)
54-
return chain
5554

5655

5756
def run_trulens_evaluation(vector_store: VectorStore, llm: BaseLanguageModel):

libs/e2e-tests/e2e_tests/test_utils/astradb_vector_store_handler.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,9 @@ def run_delete(self, collection: str):
7777
running.
7878
"""
7979
self.semaphore.acquire() # Wait for a free thread
80-
future = self.executor.submit(
80+
return self.executor.submit(
8181
lambda: self._run_and_release(collection),
8282
)
83-
return future
8483

8584
def _run_and_release(self, collection: str):
8685
"""
@@ -191,14 +190,13 @@ def new_langchain_chat_memory(self, **kwargs) -> BaseChatMessageHistory:
191190
table_name=self.handler.collection_name + "_chat_memory",
192191
**kwargs,
193192
)
194-
else:
195-
return AstraDBChatMessageHistory(
196-
session_id=self.test_id,
197-
token=self.handler.token,
198-
api_endpoint=self.handler.api_endpoint,
199-
collection_name=self.handler.collection_name + "_chat_memory",
200-
**kwargs,
201-
)
193+
return AstraDBChatMessageHistory(
194+
session_id=self.test_id,
195+
token=self.handler.token,
196+
api_endpoint=self.handler.api_endpoint,
197+
collection_name=self.handler.collection_name + "_chat_memory",
198+
**kwargs,
199+
)
202200

203201
def new_llamaindex_vector_store(self, **kwargs) -> EnhancedLlamaIndexVectorStore:
204202
logging.info(

libs/e2e-tests/e2e_tests/test_utils/cassandra_vector_store_handler.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,10 @@ def search_documents(self, vector: List[float], limit: int) -> List[str]:
9999
result["body_blob"]
100100
for result in self.table.ann_search(vector=vector, n=limit)
101101
]
102-
else:
103-
return [
104-
result["document"]
105-
for result in self.table.search(embedding_vector=vector, top_k=limit)
106-
]
102+
return [
103+
result["document"]
104+
for result in self.table.search(embedding_vector=vector, top_k=limit)
105+
]
107106

108107

109108
class EnhancedCassandraLlamaIndexVectorStore(

libs/knowledge-graph/ragstack_knowledge_graph/render.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,10 @@ def _node_id(node: Node) -> int:
3737
node_key = (node.id, node.type)
3838
if node_id := nodes.get(node_key):
3939
return node_id
40-
else:
41-
node_id = f"{len(nodes)}"
42-
nodes[node_key] = node_id
43-
dot.node(node_id, label=_node_label(node))
44-
return node_id
40+
node_id = f"{len(nodes)}"
41+
nodes[node_key] = node_id
42+
dot.node(node_id, label=_node_label(node))
43+
return node_id
4544

4645
for graph_document in graph_documents:
4746
for node in graph_document.nodes:

libs/knowledge-graph/ragstack_knowledge_graph/traverse.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,7 @@ def fetch_relationships(distance: int, source: Node) -> None:
195195

196196
if error is not None:
197197
raise error
198-
else:
199-
return results
198+
return results
200199

201200

202201
class AsyncPagedQuery:
@@ -220,8 +219,7 @@ async def next(self):
220219
self.current_page_future = asyncio.Future()
221220
self.response_future.start_fetching_next_page()
222221
return self.depth, page, self
223-
else:
224-
return self.depth, page, None
222+
return self.depth, page, None
225223

226224

227225
async def atraverse(

libs/knowledge-graph/tests/conftest.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,7 @@ def llm() -> BaseChatModel:
6060
except ValueError:
6161
pytest.skip("Unable to create OpenAI model")
6262
else:
63-
model = ChatOpenAI(model_name="gpt-4o", temperature=0.0)
64-
return model
63+
return ChatOpenAI(model_name="gpt-4o", temperature=0.0)
6564

6665

6766
class DataFixture:

libs/knowledge-store/notebooks/astra_support.ipynb

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,9 @@
118118
"def select_content(soup: BeautifulSoup, url: str) -> BeautifulSoup:\n",
119119
" if url.startswith(\"https://docs.datastax.com/en/\"):\n",
120120
" return soup.select_one(\"article.doc\")\n",
121-
" elif url.startswith(\"https://github.com\"):\n",
121+
" if url.startswith(\"https://github.com\"):\n",
122122
" return soup.select_one(\"article.entry-content\")\n",
123-
" else:\n",
124-
" return soup\n",
123+
" return soup\n",
125124
"\n",
126125
"\n",
127126
"async def load_pages(urls: Iterable[str]) -> AsyncIterator[Document]:\n",
@@ -326,10 +325,9 @@
326325
"\n",
327326
"\n",
328327
"def format_docs(docs):\n",
329-
" formatted = \"\\n\\n\".join(\n",
328+
" return \"\\n\\n\".join(\n",
330329
" f\"From {doc.metadata['content_id']}: {doc.page_content}\" for doc in docs\n",
331-
" )\n",
332-
" return formatted"
330+
" )"
333331
]
334332
},
335333
{

0 commit comments

Comments
 (0)