@@ -780,6 +780,36 @@ def yellow_taxi_endpoint_generator(ticket_data: Any) -> list[flight.FlightEndpoi
780780 ),
781781)
782782
783+
784+ def collatz_step_count (n : int ) -> int :
785+ steps = 0
786+ while n != 1 :
787+ if n % 2 == 0 :
788+ n //= 2
789+ else :
790+ n = 3 * n + 1
791+ steps += 1
792+ return steps
793+
794+
795+ def collatz (inputs : pa .Array ) -> pa .Array :
796+ results = [collatz_step_count (n ) for n in inputs .to_pylist ()]
797+ return pa .array (results , type = pa .int64 ())
798+
799+
800+ def collatz_steps (n : int ) -> list [int ]:
801+ steps = 0
802+ results = []
803+ while n != 1 :
804+ if n % 2 == 0 :
805+ n //= 2
806+ else :
807+ n = 3 * n + 1
808+ results .append (n )
809+ steps += 1
810+ return results
811+
812+
783813util_schema = SchemaCollection (
784814 scalar_functions_by_name = CaseInsensitiveDict (
785815 {
@@ -798,6 +828,18 @@ def yellow_taxi_endpoint_generator(ticket_data: Any) -> list[flight.FlightEndpoi
798828 output_schema = pa .schema ([pa .field ("result" , pa .int64 ())]),
799829 handler = add_handler ,
800830 ),
831+ "collatz" : ScalarFunction (
832+ input_schema = pa .schema ([pa .field ("n" , pa .int64 ())]),
833+ output_schema = pa .schema ([pa .field ("result" , pa .int64 ())]),
834+ handler = lambda table : collatz (table .column (0 )),
835+ ),
836+ "collatz_sequence" : ScalarFunction (
837+ input_schema = pa .schema ([pa .field ("n" , pa .int64 ())]),
838+ output_schema = pa .schema ([pa .field ("result" , pa .list_ (pa .int64 ()))]),
839+ handler = lambda table : pa .array (
840+ [collatz_steps (n ) for n in table .column (0 ).to_pylist ()], type = pa .list_ (pa .int64 ())
841+ ),
842+ ),
801843 }
802844 ),
803845 table_functions_by_name = CaseInsensitiveDict (
0 commit comments