diff --git a/libs/text-splitters/langchain_text_splitters/base.py b/libs/text-splitters/langchain_text_splitters/base.py index ed06b78a498a9..5073b2965a6ce 100644 --- a/libs/text-splitters/langchain_text_splitters/base.py +++ b/libs/text-splitters/langchain_text_splitters/base.py @@ -330,6 +330,7 @@ class Language(str, Enum): ELIXIR = "elixir" POWERSHELL = "powershell" VISUALBASIC6 = "visualbasic6" + MYSQL = "mysql" @dataclass(frozen=True) diff --git a/libs/text-splitters/langchain_text_splitters/character.py b/libs/text-splitters/langchain_text_splitters/character.py index 22736f1df58d3..9dad39756ee34 100644 --- a/libs/text-splitters/langchain_text_splitters/character.py +++ b/libs/text-splitters/langchain_text_splitters/character.py @@ -760,6 +760,82 @@ def get_separators_for_language(language: Language) -> list[str]: " ", "", ] + if language == Language.MYSQL: + return [ + # Split along definitions + "\ncreate ", + "\nCREATE ", + "\nalter ", + "\nALTER ", + "\ndrop ", + "\nDROP ", + "\ntruncate ", + "\nTRUNCATE ", + "\nrename ", + "\nRENAME ", + "\nuse ", + "\nUSE ", + "\ndesc ", + "\nDESC ", + "\ndescribe ", + "\nDESCRIBE ", + # Split along Control and procedural code + "\nbegin", + "\nBEGIN", + "\nloop ", + "\nLOOP ", + "\nif ", + "\nIF ", + "\nwhile ", + "\nWHILE ", + "\nelse ", + "\nELSE ", + "\nelseif ", + "\nELSEIF ", + "\nrepeat ", + "\nREPEAT ", + "\nhandler ", + "\nHANDLER ", + # Split along data manipulation + "\nselect ", + "\nSELECT ", + "\ninsert ", + "\nINSERT ", + "\nupdate ", + "\nUPDATE ", + "\ndelete ", + "\nDELETE ", + "\nreplace ", + "\nREPLACE ", + "\nwith ", + "\nWITH ", + "\nshow ", + "\nSHOW ", + "\nexplain ", + "\nEXPLAIN ", + "\ncall ", + "\nCALL ", + # Split along permissions and transactions + "\ngrant ", + "\nGRANT ", + "\nrevoke ", + "\nREVOKE ", + "\ncommit ", + "\nCOMMIT ", + "\nrollback ", + "\nROLLBACK ", + "\nstart transaction", + "\nSTART TRANSACTION", + "\nset autocommit", + "\nSET AUTOCOMMIT", + "\nDELIMITER ", + "\ndelimiter ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] if language in Language._value2member_map_: msg = f"Language {language} is not implemented yet!" diff --git a/libs/text-splitters/tests/unit_tests/test_text_splitters.py b/libs/text-splitters/tests/unit_tests/test_text_splitters.py index 6825dce276573..125e2a61f93c9 100644 --- a/libs/text-splitters/tests/unit_tests/test_text_splitters.py +++ b/libs/text-splitters/tests/unit_tests/test_text_splitters.py @@ -3277,6 +3277,45 @@ def test_visualbasic6_code_splitter() -> None: ] +def test_mysql_code_splitter() -> None: + splitter = RecursiveCharacterTextSplitter.from_language( + Language.MYSQL, + chunk_size=CHUNK_SIZE, + chunk_overlap=0, + ) + code = """ +CREATE TABLE products ( + id INT PRIMARY KEY, + name VARCHAR(100) +); +INSERT INTO products VALUES (1, 'Keyboard'), (2, 'Mouse'); +SELECT * FROM products WHERE id = 1; +SELECT name FROM products ORDER BY name DESC; +""" + chunks = splitter.split_text(code) + assert chunks == [ + "CREATE TABLE", + "products (", + "id INT", + "PRIMARY KEY,", + "name", + "VARCHAR(100)", + ");", + "INSERT INTO", + "products VALUES", + "(1,", + "'Keyboard'),", + "(2, 'Mouse');", + "SELECT * FROM", + "products WHERE", + "id = 1;", + "SELECT name", + "FROM products", + "ORDER BY name", + "DESC;", + ] + + def custom_iframe_extractor(iframe_tag: Tag) -> str: iframe_src = iframe_tag.get("src", "") return f"[iframe:{iframe_src}]({iframe_src})"