diff --git a/Demo.py b/Demo.py index 303fcb4..3abad47 100644 --- a/Demo.py +++ b/Demo.py @@ -1,5 +1,6 @@ import streamlit as st from couchbase_streamlit_connector.connector import CouchbaseConnector +from couchbase.options import QueryOptions import pandas as pd import plotly.graph_objects as go import plotly.express as px @@ -22,18 +23,12 @@ def get_all_airports(_connection): @st.cache_data def get_routes_for_airports(_connection, selected_airports_df): - airports_faa = "[" - for i in range(len(selected_airports_df)): - if i != len(selected_airports_df) - 1: - airports_faa += f'"{(selected_airports_df.iloc[i])["faa"]}", ' - else: - airports_faa += f'"{(selected_airports_df.iloc[i])["faa"]}"' - airports_faa += "]" + airports_faa = str(selected_airports_df["faa"].to_list()) # Initialize a string to store FAA codes in a list format query = f""" SELECT * FROM `travel-sample`.`inventory`.`route` - WHERE (sourceairport IN {airports_faa} AND destinationairport IN {airports_faa}); + WHERE (sourceairport IN $airports_faa AND destinationairport IN $airports_faa); """ - result = _connection.query(query) + result = _connection.query(query, opts=QueryOptions(named_parameters={"airports_faa": airports_faa})) data = [] for row in result: data.append(row["route"]) @@ -41,11 +36,8 @@ def get_routes_for_airports(_connection, selected_airports_df): def plot_airports_and_routes(airports_df, routes_df): fig = go.Figure() - airport_coords = { - row["faa"]: (row["lat"], row["lon"]) - for _, row in airports_df.iterrows() - if row["faa"] is not None # Ensure faa is not null - } + filtered_airports_df = airports_df.dropna(subset=["faa"]) # Remove rows where faa is NaN + airport_coords = dict(zip(filtered_airports_df["faa"], zip(filtered_airports_df["lat"], filtered_airports_df["lon"]))) lats = [] lons = [] for _, row in routes_df.iterrows(): @@ -75,7 +67,10 @@ def plot_airports_and_routes(airports_df, routes_df): color_discrete_sequence=["red"], # Color of airport markers ) fig.add_traces(airports_markers.data) + fig.update_geos(fitbounds="locations") fig.update_layout( + map_zoom= 0.5, # Zoom level + showlegend= False, # Hide legend mapbox_style="open-street-map", margin=dict(l=0, r=0, t=50, b=0), # Remove extra margins title="Airports and Flight Routes" @@ -202,7 +197,11 @@ def get_hotels_near_landmark(_connection, landmark_lat, landmark_lon, max_distan return hotels def create_landmark_map(landmarks, hotels_near_landmark): - fig = go.Figure() + fig = go.Figure() + + centre = {"lat": 0, "lon": 0} + num_points = 0 + for hotel in hotels_near_landmark: color = 'red' if hotel.get('distance') <= 3 else 'orange' if hotel.get('distance') <= 6 else 'gold' fig.add_trace(go.Scattermap( @@ -216,6 +215,8 @@ def create_landmark_map(landmarks, hotels_near_landmark): hoverinfo='text', name=f'Hotel ({color})' )) + centre = {"lat": centre["lat"] + hotel.get('lat'), "lon": centre["lon"] + hotel.get('lon')} + num_points += 1 for landmark in landmarks: fig.add_trace(go.Scattermap( @@ -229,8 +230,16 @@ def create_landmark_map(landmarks, hotels_near_landmark): hoverinfo='text', name='Landmark' )) + centre = {"lat": centre["lat"] + landmark.get('lat', 0), "lon": centre["lon"] + landmark.get('lon', 0)} + num_points += 1 + + if num_points > 0: + centre = {"lat": centre["lat"] / num_points, "lon": centre["lon"] / num_points} + fig.update_geos(fitbounds="locations") fig.update_layout( + map_zoom=11, + map_center=centre, mapbox_style='open-street-map', margin=dict(l=0, r=0, t=50, b=0), title='Landmarks and Hotels Nearby', @@ -275,22 +284,15 @@ def get_all_cities(_connection): @st.cache_data def get_all_hotels(_connection, cities): - cities_str = "[" - for i in range(len(cities)): - if i != len(cities) - 1: - cities_str += f'"{cities[i]}", ' - else: - cities_str += f'"{cities[i]}"' - cities_str += "]" query = f""" SELECT h.*, geo.lat as lat, geo.lon as lon, ARRAY_AVG(ARRAY r.ratings.Overall FOR r IN h.reviews WHEN r.ratings.Overall IS NOT MISSING END) as avg_rating FROM `travel-sample`.inventory.hotel h WHERE h.geo.lat IS NOT MISSING AND h.type = "hotel" AND h.geo.lon IS NOT MISSING - AND h.city IN {cities_str} + AND h.city IN $cities; """ - result = _connection.query(query) + result = _connection.query(query, opts=QueryOptions(named_parameters={"cities": cities})) hotels = [] for row in result: hotels.append(row) @@ -318,6 +320,10 @@ def create_hotel_map(hotels_df): if 'avg_rating' not in hotels_df.columns: hotels_df['avg_rating'] = np.nan # Add avg_rating column if it doesn't exist hotels_df['avg_rating'] = pd.to_numeric(hotels_df['avg_rating'], errors='coerce') + centre = { + "lat": hotels_df['lat'].mean(), + "lon": hotels_df['lon'].mean() + } # Create a column for star ratings hotels_df['star_rating'] = hotels_df['avg_rating'].apply(lambda x: '⭐' * int(round(x)) if not np.isnan(x) else 'No rating') @@ -363,6 +369,8 @@ def create_hotel_map(hotels_df): fig.add_traces(no_rating_markers.data) fig.update_layout( + map_zoom=10, + map_center=centre, mapbox_style="open-street-map", margin=dict(l=0, r=0, t=50, b=0), title="Hotels (colored by average rating)", @@ -373,18 +381,11 @@ def create_hotel_map(hotels_df): ) ) - fig.update_layout( - mapbox_style="open-street-map", - margin=dict(l=0, r=0, t=50, b=0), - title="Hotels (colored by average rating)", - coloraxis_colorbar_title="Avg Rating" - ) - st.plotly_chart(fig, use_container_width=True) def tab3_visual(): all_cities = get_all_cities(connection)["city"].tolist() - cities = st.multiselect("Select cities", all_cities, default=["Newport", "Birmingham", "London"]) + cities = st.multiselect("Select cities", all_cities, default=["London"]) hotels = get_all_hotels(connection, cities) create_hotel_map(hotels)