Skip to content

Commit d65b548

Browse files
authored
Merge pull request #61 from DefangLabs/jordan/avoid-nvidia-packages
2 parents 8bbf113 + 91c0519 commit d65b548

File tree

7 files changed

+23
-28
lines changed

7 files changed

+23
-28
lines changed

.github/workflows/build.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,5 @@ jobs:
1919
working-directory: ./app
2020
run: |
2121
docker buildx build \
22-
--platform linux/amd64,linux/arm64 \
22+
--platform linux/amd64 \
2323
.

app/.dockerignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
myenv
2+
venv
23
.direnv
34
.envrc
45
__pycache__
6+
sentence-transformers
57
.tmp

app/Dockerfile

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ RUN apt-get update && apt-get install -y \
1212
git \
1313
&& rm -rf /var/lib/apt/lists/*
1414

15-
# Install Go for ARM architecture (latest supported version 1.21)
16-
RUN curl -OL https://golang.org/dl/go1.21.1.linux-arm64.tar.gz && \
17-
tar -C /usr/local -xzf go1.21.1.linux-arm64.tar.gz && \
18-
rm go1.21.1.linux-arm64.tar.gz
15+
# Install Go for x86 architecture (latest supported version 1.21)
16+
RUN curl -OL https://golang.org/dl/go1.21.1.linux-amd64.tar.gz && \
17+
tar -C /usr/local -xzf go1.21.1.linux-amd64.tar.gz && \
18+
rm go1.21.1.linux-amd64.tar.gz
1919

2020
# Set Go environment variables
2121
ENV PATH="/usr/local/go/bin:${PATH}"

app/rag_system.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,14 @@ def embed_knowledge_base(self):
3030
def normalize_query(self, query):
3131
return query.lower().strip()
3232

33-
def get_query_embedding(self, query, use_cpu=True):
33+
def get_query_embedding(self, query):
3434
normalized_query = self.normalize_query(query)
3535
query_embedding = self.model.encode([normalized_query], convert_to_tensor=True)
36-
if use_cpu:
37-
query_embedding = query_embedding.cpu()
36+
query_embedding = query_embedding.cpu()
3837
return query_embedding
3938

40-
def get_doc_embeddings(self, use_cpu=True):
41-
if use_cpu:
42-
return self.doc_embeddings.cpu()
43-
return self.doc_embeddings
39+
def get_doc_embeddings(self):
40+
return self.doc_embeddings.cpu()
4441

4542
def compute_document_scores(self, query_embedding, doc_embeddings, high_match_threshold):
4643
text_similarities = cosine_similarity(query_embedding, doc_embeddings)[0]
@@ -66,12 +63,9 @@ def compute_document_scores(self, query_embedding, doc_embeddings, high_match_th
6663

6764
return result
6865

69-
def retrieve(self, query, similarity_threshold=0.4, high_match_threshold=0.8, max_docs=5, use_cpu=True):
70-
# Note: Set use_cpu=True to run on CPU, which is useful for testing or environments without a GPU.
71-
# Set use_cpu=False to leverage GPU for better performance in production.
72-
73-
query_embedding = self.get_query_embedding(query, use_cpu)
74-
doc_embeddings = self.get_doc_embeddings(use_cpu)
66+
def retrieve(self, query, similarity_threshold=0.4, high_match_threshold=0.8, max_docs=5):
67+
query_embedding = self.get_query_embedding(query)
68+
doc_embeddings = self.get_doc_embeddings()
7569

7670
doc_scores = self.compute_document_scores(query_embedding, doc_embeddings, high_match_threshold)
7771
retrieved_docs = self.get_top_docs(doc_scores, similarity_threshold, max_docs)

app/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ scikit-learn==1.2.2
55
segment-analytics-python==2.3.3
66
numpy==1.24.4
77
sentence-transformers==2.3.1
8-
torch==2.0.1
8+
--find-links https://download.pytorch.org/whl/cpu/torch_stable.html
9+
torch==2.0.1+cpu
910
huggingface_hub==0.15.1
1011
openai==0.28.0
1112
PyYAML==6.0.2

app/test_rag_system.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ def test_get_doc_embeddings(self):
5353
def test_retrieve_fallback(self):
5454
# test a query that should return the fallback response
5555
query = "Hello"
56-
# set use_cpu to True, as testing has no GPU calculations
57-
result = self.rag_system.retrieve(query, use_cpu=True)
56+
result = self.rag_system.retrieve(query)
5857
self.assertIsInstance(result, list)
5958
self.assertGreater(len(result), 0)
6059
self.assertEqual(len(result), 1) # should return one result for fallback
@@ -67,8 +66,7 @@ def test_retrieve_fallback(self):
6766
def test_retrieve_actual_response(self):
6867
# test a query that should return an actual response from the knowledge base
6968
query = "What is Defang?"
70-
# set use_cpu to True, as testing has no GPU calculations
71-
result = self.rag_system.retrieve(query, use_cpu=True)
69+
result = self.rag_system.retrieve(query)
7270
self.assertIsInstance(result, list)
7371
self.assertGreater(len(result), 0)
7472
self.assertLessEqual(len(result), 5) # should return up to max_docs (5)
@@ -80,9 +78,8 @@ def test_retrieve_actual_response(self):
8078

8179
def test_compute_document_scores(self):
8280
query = "Does Defang have an MCP sample?"
83-
# get embeddings and move them to CPU, as testing has no GPU calculations
84-
query_embedding = self.rag_system.get_query_embedding(query, use_cpu=True)
85-
doc_embeddings = self.rag_system.get_doc_embeddings(use_cpu=True)
81+
query_embedding = self.rag_system.get_query_embedding(query)
82+
doc_embeddings = self.rag_system.get_doc_embeddings()
8683

8784
# call function and get results
8885
result = self.rag_system.compute_document_scores(query_embedding, doc_embeddings, high_match_threshold=0.8)
@@ -105,4 +102,4 @@ def test_compute_document_scores(self):
105102
print("Test for compute_document_scores passed successfully!")
106103

107104
if __name__ == '__main__':
108-
unittest.main()
105+
unittest.main()

compose.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@ services:
33
restart: always
44
domainname: ask.defang.io
55
x-defang-dns-role: arn:aws:iam::258338292852:role/dnsadmin-39a19c3
6+
platform: linux/amd64
67
build:
78
context: ./app
8-
shm_size: "30gb"
9+
dockerfile: Dockerfile
910
ports:
1011
- target: 5050
1112
published: 5050

0 commit comments

Comments
 (0)