from helperFunctions.helperFunction import *
from pyspark.sql.functions import *
from pyspark.sql.types import FloatType, IntegerType, StringType
import time

import mlflow
from mlflow.tracking import MlflowClient
from mlflow.exceptions import MlflowException
from mlflow.entities.model_registry.model_version_status import ModelVersionStatus

def wait_until_ready(model_name, model_version, client):
    # Poll the registry until the new model version leaves PENDING_REGISTRATION (give up after ~10s).
    for _ in range(10):
        model_version_details = client.get_model_version(
            name=model_name,
            version=model_version,
        )
        status = ModelVersionStatus.from_string(model_version_details.status)
        print("Model status: %s" % ModelVersionStatus.to_string(status))
        if status == ModelVersionStatus.READY:
            break
        time.sleep(1)


def evaluation(fs, taxi_data, model_name, model, training_set, run_id, client):
    taxi_data = rounded_taxi_data(taxi_data)

    cols = ['fare_amount', 'trip_distance', 'pickup_zip', 'dropoff_zip', 'rounded_pickup_datetime', 'rounded_dropoff_datetime']
    taxi_data_reordered = taxi_data.select(cols)
    display(taxi_data_reordered)

    # If no model currently exists in the Production stage, simply register the model and promote it to the Production stage.
    production_stage = "Production"
    try:
        production_versions = client.get_latest_versions(name=model_name, stages=[production_stage])
    except MlflowException:
        # The model name has not been registered at all yet.
        production_versions = []

    if not production_versions:

        # MOVE TO REGISTRATION #############################################################################################

        # artifact_path = "model"
        # model_uri = "runs:/{run_id}/{artifact_path}".format(run_id=run_id, artifact_path=artifact_path)

        latest_model_version = get_latest_model_version(model_name)
        model_uri = f"models:/taxi_example_fare_packaged/{latest_model_version}"

        model_details = mlflow.register_model(model_uri=model_uri, name=model_name)

        # Wait until the registered model version is ready.
        wait_until_ready(model_details.name, model_details.version, client)

        client.update_registered_model(
            name=model_details.name,
            description="Insert"
        )

        client.update_model_version(
            name=model_details.name,
            version=model_details.version,
            description="Insert"
        )
        #############################################################################################################################
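
        # The comment above also asks for promotion to Production in this branch. A minimal sketch of
        # that step, assuming the default registry stage names ("None"/"Staging"/"Production"):
        client.transition_model_version_stage(
            name=model_details.name,
            version=model_details.version,
            stage=production_stage,
        )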

    else:

        # Score Production - MOVING TO SCORE - #############################################################################################
        model_uri = "models:/{model_name}/{model_stage}".format(model_name=model_name, model_stage=production_stage)
        # Keep the incumbent's predictions separate so they can be compared against the challenger below.
        production_predictions = fs.score_batch(model_uri, taxi_data)

        # Get the latest version, which is the model you have just trained.
        latest_model_version = get_latest_model_version(model_name)
        model_uri = "models:/{model_name}/{latest_model_version}".format(model_name=model_name, latest_model_version=latest_model_version)
        with_predictions = fs.score_batch(model_uri, taxi_data)


        import pyspark.sql.functions as func
        cols = ['prediction', 'fare_amount', 'trip_distance', 'pickup_zip', 'dropoff_zip',
                'rounded_pickup_datetime', 'rounded_dropoff_datetime', 'mean_fare_window_1h_pickup_zip',
                'count_trips_window_1h_pickup_zip', 'count_trips_window_30m_dropoff_zip', 'dropoff_is_weekend']

        with_predictions_reordered = (
            with_predictions.select(
                cols,
            )
            .withColumnRenamed(
                "prediction",
                "predicted_fare_amount",
            )
            .withColumn(
                "predicted_fare_amount",
                func.round("predicted_fare_amount", 2),
            )
        )
        display(with_predictions_reordered)
        # Get the R2 etc. ####################################################################################################################
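
        # A minimal sketch of the comparison hinted at above, assuming score_batch keeps the original
        # "fare_amount" label next to the appended "prediction" column (as the column list above suggests):
        # compute R2 for the incumbent (production_predictions) and the challenger (with_predictions).
        from pyspark.ml.evaluation import RegressionEvaluator

        evaluator = RegressionEvaluator(labelCol="fare_amount", predictionCol="prediction", metricName="r2")
        production_r2 = evaluator.evaluate(production_predictions)
        challenger_r2 = evaluator.evaluate(with_predictions)
        print("Production R2: %.4f, Challenger R2: %.4f" % (production_r2, challenger_r2))
        # The is_improvement flag below could then be set to, e.g., challenger_r2 >= production_r2.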

        # CREATE LOGIC DEFINING WHEN TO PROMOTE MODEL (EVALUATION)
        is_improvement = True  # placeholder until the comparison logic above is implemented
        ##########################################################

        if is_improvement:

            # MOVE TO "REGISTRATION" SCRIPTS - CALL FUNCTION FROM HERE
            model_details = mlflow.register_model(model_uri=model_uri, name=model_name)

            # Wait until the registered model version is ready.
            wait_until_ready(model_details.name, model_details.version, client)

            client.update_registered_model(
                name=model_details.name,
                description="Insert"
            )

            client.update_model_version(
                name=model_details.name,
                version=model_details.version,
                description="Insert"
            )

            # Demote Staging to None
            staging_stage = "Staging"
            no_stage = "None"
            # Get the latest model version in the Staging stage, if any
            staging_versions = client.get_latest_versions(name=model_name, stages=[staging_stage])
            if staging_versions:
                client.transition_model_version_stage(
                    name=model_name,
                    version=staging_versions[0].version,
                    stage=no_stage
                )

            # Demote Production To Staging (Keeps Incumbent Model As A BackStop)
            # Get the latest model version in the Production stage, if any
            production_versions = client.get_latest_versions(name=model_name, stages=[production_stage])
            if production_versions:
                # Demote the latest model version from Production to Staging
                client.transition_model_version_stage(
                    name=model_name,
                    version=production_versions[0].version,
                    stage=staging_stage
                )

            # Get the latest registered model version. This is the challenger that will be promoted to Production
            latest_registered_version = get_latest_model_version(model_name)

            client.transition_model_version_stage(
                name=model_name,
                version=latest_registered_version,
                stage=production_stage
            )
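
            # Note: an alternative to the demotion steps above (though it archives the old Production
            # version rather than keeping it as a Staging backstop) is to let the registry handle it
            # during promotion:
            #
            # client.transition_model_version_stage(
            #     name=model_name,
            #     version=latest_registered_version,
            #     stage=production_stage,
            #     archive_existing_versions=True,
            # )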

    # Get the model URI
    latest_model_version = get_latest_model_version(model_name)
    model_uri = f"models:/taxi_example_fare_packaged/{latest_model_version}"
    with_predictions = fs.score_batch(model_uri, taxi_data)

    # If there is no model registered with this name, then register it and promote it to Production.
    # If there is a model that is registered and in the Production stage, then: 1. load it, 2. score it,
    # 3. load the model that you've just logged and compare the results,
    # 4. if better, promote the most recent version of the model to the Production stage, and demote the current Production version to Staging.


    # COMMAND ----------
    # latest_pyfunc_version = get_latest_model_version("pyfunc_taxi_fare_packaged")
    # pyfunc_model_uri = f"models:/pyfunc_taxi_fare_packaged/{latest_pyfunc_version}"
    # pyfunc_predictions = fs.score_batch(pyfunc_model_uri,
    #                                     taxi_data,
    #                                     result_type='string')

    # COMMAND ----------
    import pyspark.sql.functions as func
    cols = ['prediction', 'fare_amount', 'trip_distance', 'pickup_zip', 'dropoff_zip',
            'rounded_pickup_datetime', 'rounded_dropoff_datetime', 'mean_fare_window_1h_pickup_zip',
            'count_trips_window_1h_pickup_zip', 'count_trips_window_30m_dropoff_zip', 'dropoff_is_weekend']

    with_predictions_reordered = (
        with_predictions.select(
            cols,
        )
        .withColumnRenamed(
            "prediction",
            "predicted_fare_amount",
        )
        .withColumn(
            "predicted_fare_amount",
            func.round("predicted_fare_amount", 2),
        )
    )
    display(with_predictions_reordered)

    # COMMAND ----------
    # display(pyfunc_predictions.select('fare_amount', 'prediction'))


# COMMAND ----------

77227
78228if __name__ == "__main__" :
79229 fs = feature_store .FeatureStoreClient ()
80230 model_name = "taxi_example_fare_packaged"
81231 taxi_data = spark .read .table ("feature_store_taxi_example.nyc_yellow_taxi_with_zips" )
82- eval (fs = fs , taxi_data = taxi_data , model_name = model_name )
232+ run_id = mlflow .active_run ().info .run_id
233+
234+ # Do not log
235+
236+ # training_set will cbe returned from another function.
237+ training_set = []
238+ client = MlflowClient ()
239+ evaluation (fs = fs , taxi_data = taxi_data , model_name = model_name , training_set = training_set , run_id = run_id , client = client )