Skip to content

Commit 9ae4b07

Browse files
committed
add user import as well
1 parent 8bd7b6d commit 9ae4b07

File tree

3 files changed

+60
-4
lines changed

3 files changed

+60
-4
lines changed

cogs/admin.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""User management"""
22

3+
import datetime
4+
import io
35
import logging
46
from typing import Optional
57
import discord
68
from discord import Role, app_commands
79
from discord.ext import commands
810
from sqlalchemy import select, func, and_
911
from sqlalchemy.exc import IntegrityError
10-
12+
import csv
1113
from bot import KDRBot
1214
from database.dto.kd_roles import KDRole
1315
from database.dto.users import User
@@ -16,7 +18,7 @@
1618
from utils.register import register
1719
from utils.role_management import RoleManagement
1820
from utils.server_settings import get_guild, update_guild
19-
from utils.users import get_users_csv
21+
from utils.users import get_users_csv, import_csv
2022
from utils.voice_channel import create_voice_channel
2123
from utils.voice_channels import add_voice_channel
2224

@@ -101,7 +103,29 @@ async def export_registered_users(self, interaction: discord.Interaction) -> Non
101103
await interaction.followup.send("There are currently no registered users", ephemeral=True)
102104
return
103105
await interaction.followup.send("Registered users:", ephemeral=True, file=registered_users)
104-
106+
107+
@group.command(name="import", description="Import registered users")
108+
@app_commands.guild_only()
109+
@app_commands.default_permissions(administrator=True)
110+
@app_commands.checks.has_permissions(administrator=True)
111+
async def import_registered_users(self, interaction: discord.Interaction, file: discord.Attachment) -> None:
112+
"""Import registered users"""
113+
await interaction.response.defer()
114+
if interaction.guild is None:
115+
return # is already set to guild_only
116+
async with self.bot.db.create_session() as session:
117+
try:
118+
await import_csv(session, interaction.guild_id, file)
119+
await interaction.followup.send(
120+
"Users have been imported!",
121+
ephemeral=True,
122+
)
123+
except Exception as e:
124+
await interaction.followup.send(
125+
"Something went wrong during the import, did you deliver the data in the same order as the export function?",
126+
ephemeral=True,
127+
)
128+
print(e)
105129

106130
@group.command(name="unregister", description="Unregister a user")
107131
@app_commands.describe(username="EA username")

database/dto/users.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,15 @@ class User(Base):
2424
UniqueConstraint(server_id, discord_id, player_id, user_id),
2525
) # must be a tuple!
2626

27+
def import_user(self, row, header):
28+
import_data = dict(zip(header, row))
29+
self.server_id = int(import_data.get("server_id"))
30+
self.discord_id = int(import_data.get("discord_id"))
31+
self.username = import_data.get("username")
32+
self.user_id = int(import_data.get("user_id"))
33+
return self
34+
35+
2736
async def update_kdr(self, session: AsyncSession, kdr_role_id: int | None):
2837
stmt = (
2938
update(User)

utils/users.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import csv
2+
import datetime
23
import io
34

45
import discord
56
from sqlalchemy import select
7+
from sqlalchemy.exc import IntegrityError
68
from sqlalchemy.ext.asyncio import AsyncSession
79

810
from database.dto.users import User
11+
from database.error_handling import is_unique_violation
912

1013
async def get_users_csv(session: AsyncSession, server_id: int) -> tuple[int, discord.File]:
1114
stmt = (
@@ -20,4 +23,24 @@ async def get_users_csv(session: AsyncSession, server_id: int) -> tuple[int, dis
2023
for row in res:
2124
outcsv.writerow([row[0].server_id, row[0].discord_id, row[0].username, row[0].player_id, row[0].kdr_role_id, row[0].user_id, row[0].created_at, row[0].updated_at])
2225
data_stream.seek(0)
23-
return (total, discord.File(data_stream, filename="channel_names.csv"))
26+
return (total, discord.File(data_stream, filename="channel_names.csv"))
27+
28+
async def import_csv(session: AsyncSession, server_id: int, file: discord.Attachment):
29+
if file.filename.endswith(".csv"):
30+
res = await file.read()
31+
test = io.StringIO(res.decode("utf-8"))
32+
data = list(csv.reader(test, delimiter=','))
33+
for i in data[1:]:
34+
try:
35+
user = User().import_user(i, data[0])
36+
user.server_id = server_id
37+
user.created_at = datetime.datetime.now()
38+
user.updated_at = datetime.datetime.now()
39+
session.add(user)
40+
await session.commit()
41+
except IntegrityError as ex:
42+
await session.rollback()
43+
if is_unique_violation(ex):
44+
continue # user already exists
45+
else:
46+
print(ex)

0 commit comments

Comments
 (0)