33import importlib .util
44import inspect
55import logging
6+ from decimal import Decimal
67from typing import Any , Callable , Dict , Iterator , List , Optional , Tuple , TypeVar , Union
78
89import boto3
1617
1718__all__ = ["connect" , "read_sql_query" , "read_sql_table" , "to_sql" ]
1819
19- _cx_Oracle_found = importlib .util .find_spec ("cx_Oracle " )
20- if _cx_Oracle_found :
21- import cx_Oracle # pylint: disable=import-error
20+ _oracledb_found = importlib .util .find_spec ("oracledb " )
21+ if _oracledb_found :
22+ import oracledb # pylint: disable=import-error
2223
2324_logger : logging .Logger = logging .getLogger (__name__ )
2425FuncT = TypeVar ("FuncT" , bound = Callable [..., Any ])
2526
2627
27- def _check_for_cx_Oracle (func : FuncT ) -> FuncT :
28+ def _check_for_oracledb (func : FuncT ) -> FuncT :
2829 def inner (* args : Any , ** kwargs : Any ) -> Any :
29- if not _cx_Oracle_found :
30+ if not _oracledb_found :
3031 raise ModuleNotFoundError (
31- "You need to install cx_Oracle respectively the "
32+ "You need to install oracledb respectively the "
3233 "AWS Data Wrangler package with the `oracle` extra for using the oracle module"
3334 )
3435 return func (* args , ** kwargs )
@@ -39,11 +40,11 @@ def inner(*args: Any, **kwargs: Any) -> Any:
3940 return inner # type: ignore
4041
4142
def _validate_connection(con: "oracledb.Connection") -> None:
    """Ensure ``con`` is an ``oracledb.Connection``.

    Parameters
    ----------
    con : oracledb.Connection
        Object to validate.

    Raises
    ------
    exceptions.InvalidConnection
        If ``con`` is not an instance of ``oracledb.Connection``.
    """
    if not isinstance(con, oracledb.Connection):
        # The message names the argument as `con`, matching the parameter name used here.
        raise exceptions.InvalidConnection(
            "Invalid 'con' argument, please pass an "
            "oracledb.Connection object. Use oracledb.connect() to use "
            "credentials directly or wr.oracle.connect() to fetch it from the Glue Catalog."
        )
4950
@@ -54,7 +55,7 @@ def _get_table_identifier(schema: Optional[str], table: str) -> str:
5455 return table_identifier
5556
5657
57- def _drop_table (cursor : "cx_Oracle .Cursor" , schema : Optional [str ], table : str ) -> None :
58+ def _drop_table (cursor : "oracledb .Cursor" , schema : Optional [str ], table : str ) -> None :
5859 table_identifier = _get_table_identifier (schema , table )
5960 sql = f"""
6061BEGIN
@@ -70,15 +71,15 @@ def _drop_table(cursor: "cx_Oracle.Cursor", schema: Optional[str], table: str) -
7071 cursor .execute (sql )
7172
7273
73- def _does_table_exist (cursor : "cx_Oracle .Cursor" , schema : Optional [str ], table : str ) -> bool :
74+ def _does_table_exist (cursor : "oracledb .Cursor" , schema : Optional [str ], table : str ) -> bool :
7475 schema_str = f"OWNER = '{ schema } ' AND" if schema else ""
7576 cursor .execute (f"SELECT * FROM ALL_TABLES WHERE { schema_str } TABLE_NAME = '{ table } '" )
7677 return len (cursor .fetchall ()) > 0
7778
7879
7980def _create_table (
8081 df : pd .DataFrame ,
81- cursor : "cx_Oracle .Cursor" ,
82+ cursor : "oracledb .Cursor" ,
8283 table : str ,
8384 schema : str ,
8485 mode : str ,
@@ -105,18 +106,18 @@ def _create_table(
105106 cursor .execute (sql )
106107
107108
108- @_check_for_cx_Oracle
109+ @_check_for_oracledb
109110def connect (
110111 connection : Optional [str ] = None ,
111112 secret_id : Optional [str ] = None ,
112113 catalog_id : Optional [str ] = None ,
113114 dbname : Optional [str ] = None ,
114115 boto3_session : Optional [boto3 .Session ] = None ,
115116 call_timeout : Optional [int ] = 0 ,
116- ) -> "cx_Oracle .Connection" :
117- """Return a cx_Oracle connection from a Glue Catalog Connection.
117+ ) -> "oracledb .Connection" :
118+ """Return an oracledb connection from a Glue Catalog Connection.
118119
119- https://github.com/oracle/python-cx_Oracle
120+ https://github.com/oracle/python-oracledb
120121
121122 Note
122123 ----
@@ -148,13 +149,13 @@ def connect(
148149 call_timeout: Optional[int]
149150 This is the time in milliseconds that a single round-trip to the database may take before a timeout will occur.
150151 The default is 0 which means no timeout.
151- This parameter is forwarded to cx_Oracle .
152+ This parameter is forwarded to oracledb .
152153 https://python-oracledb.readthedocs.io/en/latest/api_manual/connection.html#Connection.call_timeout
153154
154155 Returns
155156 -------
156- cx_Oracle .Connection
157- cx_Oracle connection.
157+ oracledb .Connection
158+ oracledb connection.
158159
159160 Examples
160161 --------
@@ -174,22 +175,22 @@ def connect(
174175 f"Invalid connection type ({ attrs .kind } . It must be an oracle connection.)"
175176 )
176177
177- connection_dsn = cx_Oracle .makedsn (attrs .host , attrs .port , service_name = attrs .database )
178+ connection_dsn = oracledb .makedsn (attrs .host , attrs .port , service_name = attrs .database )
178179 _logger .debug ("DSN: %s" , connection_dsn )
179- oracle_connection = cx_Oracle .connect (
180+ oracle_connection = oracledb .connect (
180181 user = attrs .user ,
181182 password = attrs .password ,
182183 dsn = connection_dsn ,
183184 )
184- # cx_Oracle .connect does not have a call_timeout attribute, it has to be set separatly
185+ # oracledb.connect does not have a call_timeout attribute, it has to be set separately
185186 oracle_connection .call_timeout = call_timeout
186187 return oracle_connection
187188
188189
189- @_check_for_cx_Oracle
190+ @_check_for_oracledb
190191def read_sql_query (
191192 sql : str ,
192- con : "cx_Oracle .Connection" ,
193+ con : "oracledb .Connection" ,
193194 index_col : Optional [Union [str , List [str ]]] = None ,
194195 params : Optional [Union [List [Any ], Tuple [Any , ...], Dict [Any , Any ]]] = None ,
195196 chunksize : Optional [int ] = None ,
@@ -203,8 +204,8 @@ def read_sql_query(
203204 ----------
204205 sql : str
205206 SQL query.
206- con : cx_Oracle .Connection
207- Use cx_Oracle .connect() to use credentials directly or wr.oracle.connect() to fetch it from the Glue Catalog.
207+ con : oracledb .Connection
208+ Use oracledb .connect() to use credentials directly or wr.oracle.connect() to fetch it from the Glue Catalog.
208209 index_col : Union[str, List[str]], optional
209210 Column(s) to set as index(MultiIndex).
210211 params : Union[List, Tuple, Dict], optional
@@ -252,10 +253,10 @@ def read_sql_query(
252253 )
253254
254255
255- @_check_for_cx_Oracle
256+ @_check_for_oracledb
256257def read_sql_table (
257258 table : str ,
258- con : "cx_Oracle .Connection" ,
259+ con : "oracledb .Connection" ,
259260 schema : Optional [str ] = None ,
260261 index_col : Optional [Union [str , List [str ]]] = None ,
261262 params : Optional [Union [List [Any ], Tuple [Any , ...], Dict [Any , Any ]]] = None ,
@@ -270,8 +271,8 @@ def read_sql_table(
270271 ----------
271272 table : str
272273 Table name.
273- con : cx_Oracle .Connection
274- Use cx_Oracle .connect() to use credentials directly or wr.oracle.connect() to fetch it from the Glue Catalog.
274+ con : oracledb .Connection
275+ Use oracledb .connect() to use credentials directly or wr.oracle.connect() to fetch it from the Glue Catalog.
275276 schema : str, optional
276277 Name of SQL schema in database to query (if database flavor supports this).
277278 Uses default schema if None (default).
@@ -324,11 +325,11 @@ def read_sql_table(
324325 )
325326
326327
327- @_check_for_cx_Oracle
328+ @_check_for_oracledb
328329@apply_configs
329330def to_sql (
330331 df : pd .DataFrame ,
331- con : "cx_Oracle .Connection" ,
332+ con : "oracledb .Connection" ,
332333 table : str ,
333334 schema : str ,
334335 mode : str = "append" ,
@@ -344,8 +345,8 @@ def to_sql(
344345 ----------
345346 df : pandas.DataFrame
346347 Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html
347- con : cx_Oracle .Connection
348- Use cx_Oracle .connect() to use credentials directly or wr.oracle.connect() to fetch it from the Glue Catalog.
348+ con : oracledb .Connection
349+ Use oracledb .connect() to use credentials directly or wr.oracle.connect() to fetch it from the Glue Catalog.
349350 table : str
350351 Table name
351352 schema : str
@@ -424,3 +425,36 @@ def to_sql(
424425 con .rollback ()
425426 _logger .error (ex )
426427 raise
428+
429+
def detect_oracle_decimal_datatype(cursor: Any) -> Dict[str, pa.DataType]:
    """Determine if a given Oracle column is a decimal, not just a standard float value.

    Parameters
    ----------
    cursor : Any
        Database cursor. Only ``oracledb.Cursor`` instances are inspected; any other
        object yields an empty mapping.

    Returns
    -------
    Dict[str, pa.DataType]
        Mapping of column name to ``pa.decimal128(precision, scale)`` for every
        ``NUMBER`` column in ``cursor.description`` whose scale is greater than zero.
    """
    dtype: Dict[str, pa.DataType] = {}
    _logger.debug("cursor type: %s", type(cursor))
    # `oracledb` is imported conditionally, so the name only exists when the spec lookup
    # succeeded; check the flag first to avoid a NameError for non-Oracle cursors.
    if _oracledb_found and isinstance(cursor, oracledb.Cursor):
        # Oracle stores DECIMAL as the NUMBER type
        # `description` is None until a statement producing rows has been executed.
        for row in cursor.description or ():
            precision, scale = row[4], row[5]
            # DB-API allows precision/scale to be None when no meaningful value is available.
            if row[1] == oracledb.DB_TYPE_NUMBER and precision is not None and scale is not None and scale > 0:
                dtype[row[0]] = pa.decimal128(precision, scale)

    _logger.debug("decimal dtypes: %s", dtype)
    return dtype
442+
443+
def handle_oracle_objects(
    col_values: List[Any], col_name: str, dtype: Optional[Dict[str, pa.DataType]] = None
) -> List[Any]:
    """Get the string representation of an Oracle LOB value, and convert float to decimal.

    Parameters
    ----------
    col_values : List[Any]
        Values of a single column.
    col_name : str
        Name of the column, used to look up its type in ``dtype``.
    dtype : Dict[str, pa.DataType], optional
        Column name to pyarrow type mapping. Columns missing from the mapping are
        left unconverted.

    Returns
    -------
    List[Any]
        The column values with every ``oracledb.LOB`` replaced by the result of its
        ``read()`` and, for columns mapped to ``pa.Decimal128Type``, every ``float``
        replaced by a ``Decimal``.
    """
    if any(isinstance(col_value, oracledb.LOB) for col_value in col_values):
        col_values = [
            col_value.read() if isinstance(col_value, oracledb.LOB) else col_value for col_value in col_values
        ]

    # `.get` instead of indexing: a mapping that only lists some columns must not raise
    # KeyError for the columns it omits.
    if dtype is not None and isinstance(dtype.get(col_name), pa.Decimal128Type):
        _logger.debug("decimal_col_values:\n%s", col_values)
        # Go through repr() so the Decimal carries the float's shortest round-trip
        # representation rather than its full binary expansion.
        col_values = [
            Decimal(repr(col_value)) if isinstance(col_value, float) else col_value for col_value in col_values
        ]

    return col_values
0 commit comments