Skip to content

Commit 7cfc615

Browse files
Merge pull request #4 from couchbase-examples/update_defaults_in_demo
changes suggested by Nithish
2 parents 5aae639 + 9c53bf6 commit 7cfc615

File tree

1 file changed

+33
-32
lines changed

1 file changed

+33
-32
lines changed

Demo.py

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import streamlit as st
22
from couchbase_streamlit_connector.connector import CouchbaseConnector
3+
from couchbase.options import QueryOptions
34
import pandas as pd
45
import plotly.graph_objects as go
56
import plotly.express as px
@@ -22,30 +23,21 @@ def get_all_airports(_connection):
2223

2324
@st.cache_data
2425
def get_routes_for_airports(_connection, selected_airports_df):
25-
airports_faa = "["
26-
for i in range(len(selected_airports_df)):
27-
if i != len(selected_airports_df) - 1:
28-
airports_faa += f'"{(selected_airports_df.iloc[i])["faa"]}", '
29-
else:
30-
airports_faa += f'"{(selected_airports_df.iloc[i])["faa"]}"'
31-
airports_faa += "]"
26+
airports_faa = str(selected_airports_df["faa"].to_list()) # Initialize a string to store FAA codes in a list format
3227
query = f"""
3328
SELECT * FROM `travel-sample`.`inventory`.`route`
34-
WHERE (sourceairport IN {airports_faa} AND destinationairport IN {airports_faa});
29+
WHERE (sourceairport IN $airports_faa AND destinationairport IN $airports_faa);
3530
"""
36-
result = _connection.query(query)
31+
result = _connection.query(query, opts=QueryOptions(named_parameters={"airports_faa": airports_faa}))
3732
data = []
3833
for row in result:
3934
data.append(row["route"])
4035
return pd.DataFrame(data)
4136

4237
def plot_airports_and_routes(airports_df, routes_df):
4338
fig = go.Figure()
44-
airport_coords = {
45-
row["faa"]: (row["lat"], row["lon"])
46-
for _, row in airports_df.iterrows()
47-
if row["faa"] is not None # Ensure faa is not null
48-
}
39+
filtered_airports_df = airports_df.dropna(subset=["faa"]) # Remove rows where faa is NaN
40+
airport_coords = dict(zip(filtered_airports_df["faa"], zip(filtered_airports_df["lat"], filtered_airports_df["lon"])))
4941
lats = []
5042
lons = []
5143
for _, row in routes_df.iterrows():
@@ -75,7 +67,10 @@ def plot_airports_and_routes(airports_df, routes_df):
7567
color_discrete_sequence=["red"], # Color of airport markers
7668
)
7769
fig.add_traces(airports_markers.data)
70+
fig.update_geos(fitbounds="locations")
7871
fig.update_layout(
72+
map_zoom= 0.5, # Zoom level
73+
showlegend= False, # Hide legend
7974
mapbox_style="open-street-map",
8075
margin=dict(l=0, r=0, t=50, b=0), # Remove extra margins
8176
title="Airports and Flight Routes"
@@ -202,7 +197,11 @@ def get_hotels_near_landmark(_connection, landmark_lat, landmark_lon, max_distan
202197
return hotels
203198

204199
def create_landmark_map(landmarks, hotels_near_landmark):
205-
fig = go.Figure()
200+
fig = go.Figure()
201+
202+
centre = {"lat": 0, "lon": 0}
203+
num_points = 0
204+
206205
for hotel in hotels_near_landmark:
207206
color = 'red' if hotel.get('distance') <= 3 else 'orange' if hotel.get('distance') <= 6 else 'gold'
208207
fig.add_trace(go.Scattermap(
@@ -216,6 +215,8 @@ def create_landmark_map(landmarks, hotels_near_landmark):
216215
hoverinfo='text',
217216
name=f'Hotel ({color})'
218217
))
218+
centre = {"lat": centre["lat"] + hotel.get('lat'), "lon": centre["lon"] + hotel.get('lon')}
219+
num_points += 1
219220

220221
for landmark in landmarks:
221222
fig.add_trace(go.Scattermap(
@@ -229,8 +230,16 @@ def create_landmark_map(landmarks, hotels_near_landmark):
229230
hoverinfo='text',
230231
name='Landmark'
231232
))
233+
centre = {"lat": centre["lat"] + landmark.get('lat', 0), "lon": centre["lon"] + landmark.get('lon', 0)}
234+
num_points += 1
235+
236+
if num_points > 0:
237+
centre = {"lat": centre["lat"] / num_points, "lon": centre["lon"] / num_points}
238+
fig.update_geos(fitbounds="locations")
232239

233240
fig.update_layout(
241+
map_zoom=11,
242+
map_center=centre,
234243
mapbox_style='open-street-map',
235244
margin=dict(l=0, r=0, t=50, b=0),
236245
title='Landmarks and Hotels Nearby',
@@ -275,22 +284,15 @@ def get_all_cities(_connection):
275284

276285
@st.cache_data
277286
def get_all_hotels(_connection, cities):
278-
cities_str = "["
279-
for i in range(len(cities)):
280-
if i != len(cities) - 1:
281-
cities_str += f'"{cities[i]}", '
282-
else:
283-
cities_str += f'"{cities[i]}"'
284-
cities_str += "]"
285287
query = f"""
286288
SELECT h.*, geo.lat as lat, geo.lon as lon, ARRAY_AVG(ARRAY r.ratings.Overall FOR r IN h.reviews WHEN r.ratings.Overall IS NOT MISSING END) as avg_rating
287289
FROM `travel-sample`.inventory.hotel h
288290
WHERE h.geo.lat IS NOT MISSING
289291
AND h.type = "hotel"
290292
AND h.geo.lon IS NOT MISSING
291-
AND h.city IN {cities_str}
293+
AND h.city IN $cities;
292294
"""
293-
result = _connection.query(query)
295+
result = _connection.query(query, opts=QueryOptions(named_parameters={"cities": cities}))
294296
hotels = []
295297
for row in result:
296298
hotels.append(row)
@@ -318,6 +320,10 @@ def create_hotel_map(hotels_df):
318320
if 'avg_rating' not in hotels_df.columns:
319321
hotels_df['avg_rating'] = np.nan # Add avg_rating column if it doesn't exist
320322
hotels_df['avg_rating'] = pd.to_numeric(hotels_df['avg_rating'], errors='coerce')
323+
centre = {
324+
"lat": hotels_df['lat'].mean(),
325+
"lon": hotels_df['lon'].mean()
326+
}
321327

322328
# Create a column for star ratings
323329
hotels_df['star_rating'] = hotels_df['avg_rating'].apply(lambda x: '⭐' * int(round(x)) if not np.isnan(x) else 'No rating')
@@ -363,6 +369,8 @@ def create_hotel_map(hotels_df):
363369
fig.add_traces(no_rating_markers.data)
364370

365371
fig.update_layout(
372+
map_zoom=10,
373+
map_center=centre,
366374
mapbox_style="open-street-map",
367375
margin=dict(l=0, r=0, t=50, b=0),
368376
title="Hotels (colored by average rating)",
@@ -373,18 +381,11 @@ def create_hotel_map(hotels_df):
373381
)
374382
)
375383

376-
fig.update_layout(
377-
mapbox_style="open-street-map",
378-
margin=dict(l=0, r=0, t=50, b=0),
379-
title="Hotels (colored by average rating)",
380-
coloraxis_colorbar_title="Avg Rating"
381-
)
382-
383384
st.plotly_chart(fig, use_container_width=True)
384385

385386
def tab3_visual():
386387
all_cities = get_all_cities(connection)["city"].tolist()
387-
cities = st.multiselect("Select cities", all_cities, default=["Newport", "Birmingham", "London"])
388+
cities = st.multiselect("Select cities", all_cities, default=["London"])
388389
hotels = get_all_hotels(connection, cities)
389390
create_hotel_map(hotels)
390391

0 commit comments

Comments
 (0)